
Commit 8ae4510

fix: Address review comments and add upgrades
1 parent bf7906c commit 8ae4510

8 files changed
+155 -139 lines changed
docsrc/contributors/writing_dynamo_aten_lowering_passes.rst

+109
@@ -0,0 +1,109 @@
.. _writing_dynamo_aten_lowering_passes:

Writing Dynamo ATen Lowering Passes
===================================

Basics of a Lowering Pass
-------------------------

ATen lowering passes are Python functions which take as input a graph of ATen operators, apply some desired modification such as operator coalescing/fusion, operator replacement, subgraph rewriting, custom operator insertion, or other operations on a `torch.fx.GraphModule`, then return the modified graph to the caller. These lowering passes generally modify the graph in-place and return the same input object.
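
For instance, a minimal operator-replacement pass following this pattern might look like the sketch below. It is illustrative only and not one of the shipped passes: the `aten.sigmoid`-to-`aten.tanh` swap and the function name are hypothetical, and `torch` is assumed to be imported. Because both operators share the same single-Tensor signature, the node can simply be retargeted in-place.

.. code-block:: python

    def replace_sigmoid_with_tanh(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Hypothetical pass swapping every aten.sigmoid call for aten.tanh"""
        modified_graph = False

        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.sigmoid.default:
                # Retarget the call_function node to the replacement operator
                node.target = torch.ops.aten.tanh.default
                modified_graph = True

        # Leave the graph in a valid, recompiled state (see the requirements below)
        if modified_graph:
            gm.graph.lint()
            gm.recompile()

        return gm
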
Lowering Pass Requirements
--------------------------

An ATen lowering pass function in Torch-TRT must satisfy two requirements:

- The function must take as input a single `torch.fx.GraphModule` and return the lowered `torch.fx.GraphModule`
- The function must leave the graph in a valid and invoke-able state, including performing any necessary linting and recompilation

See this link for information on `Graph Manipulations <https://pytorch.org/docs/stable/fx.html#graph-manipulation>`_ in FX. See below for an example of a lowering pass which repairs graphs that have inputs which are also outputs, a disallowed configuration for TRT Engines.

Example Lowering Pass
---------------------

.. code-block:: python

    def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Repair scenarios where inputs are also outputs of the graph

        TRT does not allow such cases, so we insert a clone (identity) layer
        """
        modified_graph = False

        # Extract graph placeholder Tensors
        placeholders = [
            node
            for node in gm.graph.nodes
            if (
                node.op == "placeholder"
                and isinstance(node.type, type)
                and issubclass(node.type, torch.Tensor)
            )
        ]

        for placeholder in placeholders:
            # If any placeholder has any users which are direct graph outputs
            if len(placeholder.users) >= 1 and any(
                user.op == "output" for user in placeholder.users
            ):
                modified_graph = True

                # Get direct graph outputs which are direct uses of placeholders
                direct_outputs = [user for user in placeholder.users if user.op == "output"]

                # Insert clone node for placeholder to ensure
                # placeholder is not a direct output
                with gm.graph.inserting_after(placeholder):
                    cloned_placeholder = gm.graph.call_function(
                        torch.ops.aten.clone.default,
                        args=(placeholder,),
                    )

                # Replace placeholder as output with cloned version
                for output in direct_outputs:
                    output.replace_input_with(placeholder, cloned_placeholder)

        # If the graph was modified, clean up the graph and ensure it is up-to-date
        if modified_graph:
            gm.graph.eliminate_dead_code()
            gm.graph.lint()
            gm.recompile()
            logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")

        return gm


Registering Lowering Passes
---------------------------

Lowering passes are currently registered in `py/torch_tensorrt/dynamo/lowering/passes/__init__.py`, using the `torch.fx.passes.pass_manager.PassManager` utility to assemble the list of passes in a desired order. New passes added directly to that list will be applied to graphs in the Torch-TensorRT `torch.compile` backend. Currently, we offer an ATen lowering pass registration decorator for convenience, which can be invoked either directly, or with the optional `index` keyword argument which controls where in the pass list the lowering pass will be inserted.

For instance, to insert the pass at the default location (the end of the list), the following code can be used:

.. code-block:: python

    @aten_lowering_pass
    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...

Alternatively, to insert the pass at a custom index in the pass list (such as the front), the following code can be used:

.. code-block:: python

    @aten_lowering_pass(index=0)
    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...

Utilities are also provided in `torch_tensorrt.dynamo.lowering.passes` for displaying the currently-available lowering pass list, applying those passes to an arbitrary `torch.fx.GraphModule`, and removing the lowering pass at a specific index.

.. code-block:: python

    # Print all lowering passes in the list
    print(dump_lowering_passes())

    # Apply lowering passes to a GraphModule
    apply_lowering_passes(graph_module)

    # Remove the lowering pass at index 1
    _remove_lowering_pass(index=1)
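
Putting these pieces together, a registered pass is picked up automatically when a model is compiled through the Torch-TensorRT backend. The following end-to-end sketch is illustrative only: the toy model, input shapes, and pass body are placeholders, and the `backend="torch_tensorrt"` string follows the Torch-TensorRT `torch.compile` tutorials.

.. code-block:: python

    import torch
    import torch_tensorrt  # importing registers the Torch-TensorRT backend for torch.compile
    from torch_tensorrt.dynamo.lowering import aten_lowering_pass


    @aten_lowering_pass
    def print_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Placeholder pass: inspect or modify the ATen graph here
        print(gm.graph)
        return gm


    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval().cuda()
    inputs = [torch.randn(1, 8).cuda()]

    # The registered pass runs on the lowered ATen graph during compilation
    optimized_model = torch.compile(model, backend="torch_tensorrt")
    optimized_model(*inputs)
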
**Note:** The above APIs are subject to change, as the lowering pass system evolves.

docsrc/index.rst

+2-1
@@ -73,7 +73,6 @@ Tutorials
   tutorials/_rendered_examples/dynamo/torch_compile_resnet_example
   tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
   tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
-  tutorials/_rendered_examples/dynamo/dynamo_aten_lowering_passes

Python API Documenation
------------------------
@@ -129,6 +128,7 @@ Contributor Documentation
--------------------------------
* :ref:`system_overview`
* :ref:`writing_converters`
+* :ref:`writing_dynamo_aten_lowering_passes`
* :ref:`useful_links`

.. toctree::
@@ -138,6 +138,7 @@ Contributor Documentation

   contributors/system_overview
   contributors/writing_converters
+  contributors/writing_dynamo_aten_lowering_passes
   contributors/useful_links

Indices

examples/dynamo/README.rst

-1
@@ -9,4 +9,3 @@ a number of ways you can leverage this backend to accelerate inference.
* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
-* :ref:`dynamo_aten_lowering_passes`: Custom modifications of a graph of ATen operators via lowering passes

examples/dynamo/dynamo_aten_lowering_passes.py

-113
This file was deleted.

py/torch_tensorrt/dynamo/aten_tracer.py

+2-3
@@ -6,8 +6,7 @@

import torch
from torch._export import export
-from torch_tensorrt.dynamo.backend.backends import constant_fold
-from torch_tensorrt.dynamo.lowering import get_decompositions
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
from torch_tensorrt.dynamo.utils import set_log_level

logger = logging.getLogger(__name__)
@@ -29,6 +28,6 @@ def trace(
        "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
    ):
        graph_module = export(model, tuple(inputs)).module()
-        constant_fold(graph_module)
+        graph_module = apply_lowering_passes(graph_module)
        logger.debug("Post export graph: " + str(graph_module.graph))
        return graph_module

py/torch_tensorrt/dynamo/lowering/__init__.py

+1-1
@@ -2,5 +2,5 @@
from ._fusers import * # noqa: F401
from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401
from ._pre_aot_lowering import register_substitution # noqa: F401
-from .passes import add_lowering_pass, apply_lowering_passes
+from .passes import apply_lowering_passes, aten_lowering_pass
from .substitutions import * # noqa: F401

py/torch_tensorrt/dynamo/lowering/passes/__init__.py

+33-10
@@ -1,5 +1,6 @@
import logging
-from typing import Callable, Optional
+from functools import wraps
+from typing import Callable, Optional, Tuple

import torch

@@ -18,22 +19,44 @@
logger = logging.getLogger(__name__)


-def add_lowering_pass(
-    lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule],
+LoweringPassSignature = Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
+
+
+def aten_lowering_pass(
+    *args: LoweringPassSignature,
    index: Optional[int] = None,
-) -> None:
+) -> LoweringPassSignature:
    """Adds a lowering pass to the registry, at a specified index if desired

    If no index is specified, the lowering pass is inserted at the end of the list
    """
-    ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
-    logger.debug(
-        f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
-    )
-    return
+
+    def add_lowering_pass(
+        lowering_pass: LoweringPassSignature,
+    ) -> LoweringPassSignature:
+        ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
+        logger.debug(
+            f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
+        )
+        return lowering_pass
+
+    # If there are arguments specified, the decorator may have been called as-is
+    if args:
+        # The decorator may only be called with the lowering pass
+        # The index must be specified as a keyword argument
+        if len(args) == 1 and callable(args[0]):
+            return add_lowering_pass(args[0])
+        else:
+            raise AssertionError(
+                f"aten_lowering_pass decorator called with invalid arguments {args} "
+                "To specify an index to insert the pass, use the keyword 'index='"
+            )
+    # If no arguments are specified, the decorator was called with an index keyword
+    else:
+        return add_lowering_pass


-def remove_lowering_pass(index: int) -> None:
+def _remove_lowering_pass(*, index: int) -> None:
    """Removes a lowering pass at a specific index from the registry"""
    ATEN_LOWERING_PASSES.remove_pass_with_index(index)
    logger.debug(
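
The dual call forms of the decorator above (`@aten_lowering_pass` and `@aten_lowering_pass(index=...)`) follow the usual optional-keyword-argument decorator pattern. For reference, a stripped-down, self-contained sketch of that pattern is shown below; the names and the plain list standing in for the pass manager are illustrative only, not part of this commit:

    from typing import Callable, List, Optional

    PASSES: List[Callable] = []

    def register(*args: Callable, index: Optional[int] = None) -> Callable:
        def _add(fn: Callable) -> Callable:
            # Insert at the requested index, defaulting to the end of the list
            PASSES.insert(index if index is not None else len(PASSES), fn)
            return fn

        # Bare form: @register passes the decorated function directly in args
        if args:
            assert len(args) == 1 and callable(args[0])
            return _add(args[0])

        # Parameterized form: @register(index=0) returns _add, applied afterwards
        return _add

    @register
    def pass_a(gm):
        return gm

    @register(index=0)
    def pass_b(gm):
        return gm

    assert PASSES == [pass_b, pass_a]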

tests/py/dynamo/lowering/test_aten_lowering_passes.py

+8-10
@@ -59,36 +59,34 @@ class TestLoweringPassMembership(TestCase):
    def insert_at_end(self):
        from torch_tensorrt.dynamo.lowering.passes import (
            ATEN_LOWERING_PASSES,
-            add_lowering_pass,
-            remove_lowering_pass,
+            _remove_lowering_pass,
+            aten_lowering_pass,
        )

+        @aten_lowering_pass
        def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
            return gm

-        add_lowering_pass(identity_pass)
-
        self.assertEqual(identity_pass, ATEN_LOWERING_PASSES.passes[-1])

-        remove_lowering_pass(-1)
+        _remove_lowering_pass(-1)

        self.assertNotIn(identity_pass, ATEN_LOWERING_PASSES.passes)

    def insert_at_index(self):
        from torch_tensorrt.dynamo.lowering.passes import (
            ATEN_LOWERING_PASSES,
-            add_lowering_pass,
-            remove_lowering_pass,
+            _remove_lowering_pass,
+            aten_lowering_pass,
        )

+        @aten_lowering_pass(index=0)
        def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
            return gm

-        add_lowering_pass(identity_pass, 0)
-
        self.assertEqual(identity_pass, ATEN_LOWERING_PASSES.passes[0])

-        remove_lowering_pass(0)
+        _remove_lowering_pass(0)

        self.assertNotIn(identity_pass, ATEN_LOWERING_PASSES.passes)

0 commit comments
