 import torch.utils._pytree as pytree
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import _aot_export_function
-from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant
 from torch._ops import OpOverload
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
-from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs

@@ -75,7 +74,7 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             # Invoke AOTAutograd to translate operators to aten
-            graph_module = aot_export_for_compile(
+            gm = aot_export_for_compile(
                 gm,
                 sample_inputs,
                 decompositions=get_decompositions(
@@ -85,10 +84,10 @@ def _pretraced_backend(

             logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

-            constant_fold(graph_module)
+            gm = apply_lowering_passes(gm)

             trt_compiled = compile_module(
-                graph_module,
+                gm,
                 sample_inputs,
                 settings=settings,
             )
@@ -112,35 +111,6 @@ def _pretraced_backend(
         raise


-@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
-def constant_fold(gm: torch.fx.GraphModule) -> Any:
-    """Adapted from:
-    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
-
-    Folds constants in the graph module, not skipping constructors
-
-    Modifies the graph in-place and replaces node with constants
-    """
-    cf = ConstantFolder(gm, skip_constructors=False)
-    cf.run()
-
-    for node, constant in cf.node_replacements.items():
-        replace_node_with_constant(gm, node, constant)
-
-    erased_params = []
-    for node in gm.graph.nodes:
-        if node.op == "get_attr" and len(node.users) == 0:
-            delattr(gm, node.target)
-            erased_params.append(node)
-
-    for node in erased_params:
-        gm.graph.erase_node(node)
-
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
-
-
 def aot_export_for_compile(
     func: torch.fx.GraphModule,
     args: Sequence[torch.Tensor],
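For context on the change above: the standalone constant_fold helper is deleted and its work is expected to be covered by apply_lowering_passes. As a reference, here is a minimal sketch, built from the removed code itself, of that logic reshaped as a chainable lowering pass. The name constant_fold_pass and the return-the-module convention are illustrative assumptions, not the actual torch_tensorrt lowering registration; ConstantFolder and replace_node_with_constant are the real torch._inductor.constant_folding utilities the deleted code imported.

# Sketch only: the removed helper reshaped as a chainable lowering pass.
# The pass name and return convention are hypothetical; the body mirrors
# the deleted constant_fold function.
import torch
from torch._inductor.constant_folding import (
    ConstantFolder,
    replace_node_with_constant,
)


@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
def constant_fold_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Evaluate constant-computable subgraphs (constructors included) and
    # swap the producing nodes for baked-in constants.
    cf = ConstantFolder(gm, skip_constructors=False)
    cf.run()
    for node, constant in cf.node_replacements.items():
        replace_node_with_constant(gm, node, constant)

    # Delete get_attr nodes (and their backing attributes) that the
    # folding left without any users.
    erased_params = []
    for node in gm.graph.nodes:
        if node.op == "get_attr" and len(node.users) == 0:
            delattr(gm, node.target)
            erased_params.append(node)
    for node in erased_params:
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm

In this shape the pass composes the same way the new call site does, e.g. gm = constant_fold_pass(gm), which is why the diff can replace the in-place constant_fold(graph_module) call with a single gm = apply_lowering_passes(gm).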