Commit 245ea94

feat: Add ATen lowering pass system

1 parent 692921e commit 245ea94

File tree

8 files changed: +188 -39 lines changed

py/torch_tensorrt/dynamo/backend/backends.py

+4 -34

@@ -9,11 +9,10 @@
 import torch.utils._pytree as pytree
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import _aot_export_function
-from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant
 from torch._ops import OpOverload
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
-from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs

@@ -75,7 +74,7 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             # Invoke AOTAutograd to translate operators to aten
-            graph_module = aot_export_for_compile(
+            gm = aot_export_for_compile(
                 gm,
                 sample_inputs,
                 decompositions=get_decompositions(
@@ -85,10 +84,10 @@ def _pretraced_backend(

             logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

-            constant_fold(graph_module)
+            gm = apply_lowering_passes(gm)

             trt_compiled = compile_module(
-                graph_module,
+                gm,
                 sample_inputs,
                 settings=settings,
             )
@@ -112,35 +111,6 @@ def _pretraced_backend(
         raise


-@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
-def constant_fold(gm: torch.fx.GraphModule) -> Any:
-    """Adapted from:
-        https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
-
-    Folds constants in the graph module, not skipping constructors
-
-    Modifies the graph in-place and replaces node with constants
-    """
-    cf = ConstantFolder(gm, skip_constructors=False)
-    cf.run()
-
-    for node, constant in cf.node_replacements.items():
-        replace_node_with_constant(gm, node, constant)
-
-    erased_params = []
-    for node in gm.graph.nodes:
-        if node.op == "get_attr" and len(node.users) == 0:
-            delattr(gm, node.target)
-            erased_params.append(node)
-
-    for node in erased_params:
-        gm.graph.erase_node(node)
-
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
-
-
 def aot_export_for_compile(
     func: torch.fx.GraphModule,
     args: Sequence[torch.Tensor],
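
Net effect in this file: the inline constant_fold helper is removed and the freshly exported aten graph is routed through the new lowering pass manager before TRT compilation. A condensed sketch of the resulting flow (simplified; the fake-mode context and error handling are omitted, and get_decompositions is shown without its arguments):

    # aten export, then the registered ATen lowering passes, then TRT compilation
    gm = aot_export_for_compile(gm, sample_inputs, decompositions=get_decompositions())
    gm = apply_lowering_passes(gm)  # constant folding + input-as-output repair
    trt_compiled = compile_module(gm, sample_inputs, settings=settings)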

py/torch_tensorrt/dynamo/lowering/__init__.py

+1 -0

@@ -2,4 +2,5 @@
 from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401
+from .passes import add_lowering_pass, apply_lowering_passes
 from .substitutions import *  # noqa: F401

py/torch_tensorrt/dynamo/lowering/passes/__init__.py (new file; the path is not shown in the extracted page and is inferred from the `from .passes import ...` line above and the setup.py package list below)

+27 -0

@@ -0,0 +1,27 @@
+from typing import Callable
+
+import torch
+from torch.fx.passes.pass_manager import PassManager
+
+from .constant_folding import constant_fold
+from .repair_input_as_output import repair_input_as_output
+
+ATEN_LOWERING_PASSES = PassManager.build_from_passlist(
+    [
+        constant_fold,
+        repair_input_as_output,
+    ]
+)
+
+
+def add_lowering_pass(
+    lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
+) -> None:
+    """Adds a lowering pass to the registry"""
+    ATEN_LOWERING_PASSES.add_pass(lowering_pass)
+    return
+
+
+def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Applies the lowering passes to a graph module, returns the modified GraphModule"""
+    return ATEN_LOWERING_PASSES(gm)
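
The PassManager above makes the ATen lowering stage extensible: apply_lowering_passes is what the backend now calls on the post-AOT graph, and add_lowering_pass lets users append their own GraphModule-to-GraphModule passes. A minimal usage sketch (the pass remove_detach below is a hypothetical example, not part of this commit):

    import torch
    from torch_tensorrt.dynamo.lowering import add_lowering_pass, apply_lowering_passes

    def remove_detach(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Hypothetical pass: drop aten.detach nodes, which are no-ops for inference
        for node in list(gm.graph.nodes):
            if node.op == "call_function" and node.target == torch.ops.aten.detach.default:
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
        gm.graph.lint()
        gm.recompile()
        return gm

    add_lowering_pass(remove_detach)  # appended after the built-in passes
    # The backend later runs: gm = apply_lowering_passes(gm)
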
py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py (new file; the path is not shown in the extracted page and is inferred from the `from .constant_folding import constant_fold` line above)

+39 -0

@@ -0,0 +1,39 @@
+import logging
+
+import torch
+from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant
+
+logger = logging.getLogger(__name__)
+
+
+@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
+def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Adapted from:
+        https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
+
+    Folds constants in the graph module, not skipping constructors
+
+    Modifies the graph in-place and replaces node with constants
+    """
+    cf = ConstantFolder(gm, skip_constructors=False)
+    cf.run()
+
+    for node, constant in cf.node_replacements.items():
+        replace_node_with_constant(gm, node, constant)
+
+    erased_params = []
+    for node in gm.graph.nodes:
+        if node.op == "get_attr" and len(node.users) == 0:
+            delattr(gm, node.target)
+            erased_params.append(node)
+
+    for node in erased_params:
+        gm.graph.erase_node(node)
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+
+    logger.debug(f"Graph after constant folding:\n{gm.graph}")
+
+    return gm
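
To illustrate the effect of this pass: any subgraph with no dependence on the graph inputs is evaluated once and baked in as a frozen attribute, so it never reaches the converter or the runtime. A minimal sketch using constant_fold as defined above (the toy module is illustrative only):

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x):
            # This subexpression does not depend on x, so ConstantFolder can
            # precompute it and replace it with a single get_attr constant
            scale = torch.ones(5, 7) * 3
            return x + scale

    gm = torch.fx.symbolic_trace(Toy())
    gm = constant_fold(gm)  # the ones/mul subgraph becomes a frozen constant
    print(gm.graph)
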
py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py (new file; the path is not shown in the extracted page and is inferred from the `from .repair_input_as_output import repair_input_as_output` line above)

+45 -0

@@ -0,0 +1,45 @@
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Repair scenarios where inputs are also outputs of the graph
+
+    TRT does not allow such cases, so we insert a `clone` (identity) layer
+    """
+    modified_graph = False
+
+    # Extract graph placeholders
+    placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
+
+    for placeholder in placeholders:
+        # If any placeholder has any users which are direct graph outputs
+        if len(placeholder.users) >= 1 and any(
+            user.op == "output" for user in placeholder.users
+        ):
+            modified_graph = True
+
+            # Get direct graph outputs which are direct uses of placeholders
+            direct_outputs = [user for user in placeholder.users if user.op == "output"]
+
+            # Insert clone node for placeholder to ensure placeholder is not a direct output
+            with gm.graph.inserting_after(placeholder):
+                cloned_placeholder = gm.graph.call_function(
+                    torch.ops.aten.clone.default,
+                    args=(placeholder,),
+                )
+
+            # Replace placeholder as output with cloned version
+            for output in direct_outputs:
+                output.replace_input_with(placeholder, cloned_placeholder)
+
+    if modified_graph:
+        gm.graph.eliminate_dead_code()
+        gm.graph.lint()
+        gm.recompile()
+        logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")
+
+    return gm
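
For context, the situation this pass repairs arises whenever a forward method returns one of its inputs untouched: the corresponding placeholder node is then a direct graph output, which TensorRT rejects. A minimal sketch of the scenario, using repair_input_as_output as defined above (the toy module is illustrative only):

    import torch

    class PassThrough(torch.nn.Module):
        def forward(self, x, y):
            # x is returned unchanged, so its placeholder feeds the output node directly
            return x, y + 1

    gm = torch.fx.symbolic_trace(PassThrough())
    gm = repair_input_as_output(gm)
    # The graph output now reads from aten.clone(x) instead of the x placeholder itself
    print(gm.graph)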

setup.py

+2 -0

@@ -392,6 +392,7 @@ def run(self):
         "torch_tensorrt.dynamo.conversion.impl.unary",
         "torch_tensorrt.dynamo.lowering",
         "torch_tensorrt.dynamo.lowering.substitutions",
+        "torch_tensorrt.dynamo.lowering.passes",
         "torch_tensorrt.dynamo.partitioning",
         "torch_tensorrt.dynamo.runtime",
         "torch_tensorrt.dynamo.tools",
@@ -419,6 +420,7 @@ def run(self):
         "torch_tensorrt.dynamo.conversion.impl.unary": "py/torch_tensorrt/dynamo/conversion/impl/unary",
         "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering",
         "torch_tensorrt.dynamo.lowering.substitutions": "py/torch_tensorrt/dynamo/lowering/substitutions",
+        "torch_tensorrt.dynamo.lowering.passes": "py/torch_tensorrt/dynamo/lowering/passes",
         "torch_tensorrt.dynamo.partitioning": "py/torch_tensorrt/dynamo/partitioning",
         "torch_tensorrt.dynamo.runtime": "py/torch_tensorrt/dynamo/runtime",
         "torch_tensorrt.dynamo.tools": "py/torch_tensorrt/dynamo/tools",

New test file (path not shown in the extracted page; the relative import of `..testing_utilities` places it one package level below tests/py/dynamo/)

+65 -0

@@ -0,0 +1,65 @@
+import torch
+import torch_tensorrt
+from torch.testing._internal.common_utils import TestCase, run_tests
+
+from ..testing_utilities import lower_graph_testing
+
+
+class TestInputAsOutput(TestCase):
+    def test_input_as_output(self):
+        class InputAsOutput(torch.nn.Module):
+            def forward(self, x, y):
+                y_new = y + 1
+                y_new = y_new * 7
+                return (y_new, (x, y_new))
+
+        # Operations expected to be included in the traced graph after decompositions
+        expected_ops = {torch.ops.aten.clone.default}
+
+        inputs = [
+            torch.rand(
+                5,
+                7,
+            ).cuda(),
+            torch.rand(
+                5,
+                7,
+            ).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(InputAsOutput())
+        _, expected_ops_unseen = lower_graph_testing(
+            fx_graph, inputs, expected_ops=expected_ops, min_block_size=1
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            msg=f"InputAsOutput TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
+if __name__ == "__main__":
+    run_tests()

tests/py/dynamo/testing_utilities.py

+5 -5

@@ -6,8 +6,8 @@
 import torch
 from torch._dynamo.utils import detect_fake_mode
 from torch_tensorrt.dynamo import partitioning
-from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile, constant_fold
-from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions

 DECIMALS_OF_AGREEMENT = 4

@@ -40,16 +40,16 @@ def fx_dynamo_testing_backend(
         fake_mode, "allow_non_fake_inputs", True
     ), fake_mode:
         # Invoke AOTAutograd to translate operators to aten
-        graph_module = aot_export_for_compile(
+        gm = aot_export_for_compile(
             gm,
             sample_inputs,
             decompositions=get_decompositions(),
         )

-        constant_fold(graph_module)
+        gm = apply_lowering_passes(gm)

         trt_compiled = custom_backend(
-            graph_module,
+            gm,
             sample_inputs,
         )
         return trt_compiled
