.. _writing_dynamo_aten_lowering_passes:

Writing Dynamo ATen Lowering Passes
===================================

Basics of a Lowering Pass
-------------------------

ATen lowering passes are Python functions which take as input a `torch.fx.GraphModule` containing a graph of ATen operators, apply a desired modification such as operator coalescing/fusion, operator replacement, subgraph rewriting, or custom operator insertion, and then return the modified graph to the caller. These passes generally modify the graph in-place and return the same input object.
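As an illustration, the following is a minimal sketch of such a pass performing a simple operator replacement. The pass name and the specific substitution (rewriting in-place `aten.relu_` calls to the out-of-place `aten.relu` equivalent) are hypothetical choices for demonstration, not part of the Torch-TensorRT pass list:

.. code-block:: python

    import torch

    def replace_inplace_relu(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Hypothetical pass: swap in-place aten.relu_ for out-of-place aten.relu"""
        modified = False

        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.relu_.default:
                # Retarget the call; the two overloads take the same arguments
                node.target = torch.ops.aten.relu.default
                modified = True

        # Leave the graph in a valid, invocable state before returning it
        if modified:
            gm.graph.lint()
            gm.recompile()

        return gm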

Lowering Pass Requirements
--------------------------

An ATen lowering pass function in Torch-TRT must satisfy two requirements:

- The function must take as input a single `torch.fx.GraphModule` and return the lowered `torch.fx.GraphModule`
- The function must leave the graph in a valid and invocable state, including performing any necessary linting and recompilation

See `Graph Manipulations <https://pytorch.org/docs/stable/fx.html#graph-manipulation>`_ for information on modifying FX graphs. See below for an example of a lowering pass which repairs graphs that have inputs which are also outputs, a disallowed configuration for TRT engines.

Example Lowering Pass
---------------------

.. code-block:: python

    import logging

    import torch

    logger = logging.getLogger(__name__)


    def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Repair scenarios where inputs are also outputs of the graph

        TRT does not allow such cases, so we insert a clone (identity) layer
        """
        modified_graph = False

        # Extract graph placeholder Tensors
        placeholders = [
            node
            for node in gm.graph.nodes
            if (
                node.op == "placeholder"
                and isinstance(node.type, type)
                and issubclass(node.type, torch.Tensor)
            )
        ]

        for placeholder in placeholders:
            # If any placeholder has any users which are direct graph outputs
            if len(placeholder.users) >= 1 and any(
                user.op == "output" for user in placeholder.users
            ):
                modified_graph = True

                # Get direct graph outputs which are direct uses of placeholders
                direct_outputs = [user for user in placeholder.users if user.op == "output"]

                # Insert clone node for placeholder to ensure
                # placeholder is not a direct output
                with gm.graph.inserting_after(placeholder):
                    cloned_placeholder = gm.graph.call_function(
                        torch.ops.aten.clone.default,
                        args=(placeholder,),
                    )

                # Replace placeholder as output with cloned version
                for output in direct_outputs:
                    output.replace_input_with(placeholder, cloned_placeholder)

        # If the graph was modified, clean up the graph and ensure it is up-to-date
        if modified_graph:
            gm.graph.eliminate_dead_code()
            gm.graph.lint()
            gm.recompile()
            logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")

        return gm
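
To try the pass, a graph exhibiting the disallowed pattern can be produced with `torch.fx.symbolic_trace` and then repaired. The `Passthrough` module below is a hypothetical example constructed purely for illustration:

.. code-block:: python

    # The type annotation is required so the traced placeholder carries
    # node.type == torch.Tensor, which the pass checks for
    class Passthrough(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x  # The input is also a direct output: disallowed for TRT engines

    gm = torch.fx.symbolic_trace(Passthrough())
    gm = repair_input_as_output(gm)

    # The graph now routes the input through aten.clone before the output
    print(gm.graph)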

Registering Lowering Passes
---------------------------

Lowering passes are currently registered in `py/torch_tensorrt/dynamo/lowering/passes/__init__.py`, using the `torch.fx.passes.pass_manager.PassManager` utility to assemble the list of passes in a desired order. New passes added directly to that list will be applied to graphs in the Torch-TensorRT `torch.compile` backend. Currently, we offer an ATen lowering pass registration decorator for convenience, which can be invoked either directly or with the optional `index` keyword argument, which controls where in the pass list the lowering pass will be inserted.
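
Conceptually, the registration file assembles the registered passes into a single callable pipeline. A minimal sketch of that mechanism, using `torch.fx.passes.pass_manager.PassManager` (the variable name and pass functions below are assumptions for illustration, not the exact contents of the registry), might look like:

.. code-block:: python

    from torch.fx.passes.pass_manager import PassManager

    # Each entry takes and returns a torch.fx.GraphModule
    ATEN_LOWERING_PASSES = PassManager(
        passes=[
            repair_input_as_output,  # the example pass from the previous section
            replace_inplace_relu,    # hypothetical pass from the earlier sketch
        ]
    )

    # Calling the manager runs each pass, in order, over the graph module
    lowered_gm = ATEN_LOWERING_PASSES(gm)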

Using the decorator, a pass can be registered at the default location (end of the list) as follows:

.. code-block:: python

    @aten_lowering_pass
    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...
Alternatively, to insert the pass at a custom index in the pass list (such as the front), the following code can be used:

.. code-block:: python

    @aten_lowering_pass(index=0)
    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...
Utilities are also provided in `torch_tensorrt.dynamo.lowering.passes` for displaying the currently-available lowering pass list, applying those passes to an arbitrary `torch.fx.GraphModule`, and removing the lowering pass at a specific index.

.. code-block:: python

    # Print all lowering passes in the list
    print(dump_lowering_passes())

    # Apply lowering passes to a GraphModule
    apply_lowering_passes(graph_module)

    # Remove the lowering pass at index 1
    _remove_lowering_pass(index=1)

**Note:** The above APIs are subject to change as the lowering pass system evolves.