@@ -19,6 +19,7 @@
import torch
from torch._dynamo.convert_frame import compile_lock
from torch._inductor.decomposition import select_decomp_table
+from torch.fx._lazy_graph_module import _LazyGraphModule
from torch.fx.experimental import proxy_tensor
from torch.fx.traceback import preserve_node_meta
from torch.utils import _pytree as pytree
@@ -66,7 +67,7 @@ class _TLS(Protocol):
tls: _TLS = cast("_TLS", threading.local())


-def _make_fx(fn: Callable[..., object], *args: object) -> torch.fx.GraphModule:
+def _make_fx(fn: Callable[..., object], *args: object) -> torch.fx.Graph:
    """
    We monkey patch get_proxy_slot to support Tensor/SymInt/SymFloat/SymBool in the
    graph without any origin for them. We instead insert _host_tensor(), _get_symnode()
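A note on the docstring above: the `get_proxy_slot` patch it mentions is installed around tracing. A minimal sketch of that pattern, assuming `unittest.mock.patch.object` is used, with the real replacement logic elided (the actual wrapper is the `_get_proxy_slot` referenced in the next hunk):

```python
# Hedged sketch of the get_proxy_slot monkey patch the docstring describes.
# The wrapper passed in is illustrative; the real logic (inserting
# _host_tensor()/_get_symnode() placeholder nodes) lives in this file.
import contextlib
from unittest.mock import patch

from torch.fx.experimental import proxy_tensor


@contextlib.contextmanager
def patched_get_proxy_slot(wrapper):
    # Swap the tracer's proxy lookup for one that tolerates values with
    # no recorded origin, restoring the original afterwards.
    with patch.object(proxy_tensor, "get_proxy_slot", wrapper):
        yield
```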
@@ -122,14 +123,13 @@ def _get_proxy_slot(
        current_location().set_fx_location()
        return proxy_tensor.make_fx(fn, decomposition_table=select_decomp_table())(
            *args
-        )
+        ).graph


@dataclasses.dataclass
class GraphInfo:
    graph_id: int
-    # TODO(jansel): GraphModule -> Graph to avoid fx compile
-    graph: torch.fx.GraphModule
+    graph: torch.fx.Graph

    @property
    def name(self) -> str:
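For context on the `.graph` change (not part of the diff): `make_fx` produces a `GraphModule`, and its `.graph` attribute is the underlying `torch.fx.Graph`, which is all `GraphInfo` needs to store:

```python
# make_fx(...) yields a GraphModule; keeping only .graph avoids carrying
# (and later re-generating) the compiled Python forward around.
import torch
from torch.fx.experimental import proxy_tensor


def f(x: torch.Tensor) -> torch.Tensor:
    return x.sin() + 1


gm = proxy_tensor.make_fx(f)(torch.randn(4))
assert isinstance(gm, torch.fx.GraphModule)
assert isinstance(gm.graph, torch.fx.Graph)  # what _make_fx now returns
```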
@@ -140,7 +140,9 @@ def kwargs(self) -> dict[str, object]:
        return {}

    def __str__(self) -> str:
-        output = self.graph.print_readable(print_output=False).strip()
+        output = (
+            _LazyGraphModule({}, self.graph).print_readable(print_output=False).strip()
+        )
        return textwrap.dedent(
            re.sub(
                r"forward\(self,? ?([^)]*)\)",
@@ -251,7 +253,7 @@ def __init__(self) -> None:
        self.rolled_reductions: list[RolledReductionInfo] = []
        self.grid_block_indices: list[list[int]] = []

-    def get_root(self, config: Config) -> torch.fx.GraphModule:
+    def get_root(self, config: Config) -> torch.fx.Graph:
        """If we are using a rolled reduction, return the rolled reduction graph,
        otherwise return the root graph."""
        if (root_id := self.root_id) is None:
@@ -276,18 +278,18 @@ def debug_str(self) -> str:

    def add_graph(
        self,
-        graph: torch.fx.GraphModule,
+        graph: torch.fx.Graph,
        graph_info_cls: type[GraphInfo] = GraphInfo,
        **kwargs: object,
    ) -> int:
-        graph.graph.eliminate_dead_code()
+        graph.eliminate_dead_code()
        graph_id = len(self.graphs)
        self.graphs.append(graph_info_cls(graph_id=graph_id, graph=graph, **kwargs))
        return graph_id

    def add_reduction_loop_graph(
        self,
-        graph: torch.fx.GraphModule,
+        graph: torch.fx.Graph,
        block_index: int,
        node_args: list[torch.fx.Node],
    ) -> int:
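The dropped `.graph` hop in `add_graph` works because `eliminate_dead_code` is defined on `torch.fx.Graph` itself; the `GraphModule` wrapper was only being traversed to reach it. A self-contained check:

```python
import operator

import torch

g = torch.fx.Graph()
x = g.placeholder("x")
g.call_function(operator.add, (x, x))  # dead node: result never used
g.output(x)
g.eliminate_dead_code()  # available directly on Graph
assert all(n.op != "call_function" for n in g.nodes)
```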
@@ -298,7 +300,7 @@ def add_reduction_loop_graph(
            node_args=node_args,
        )

-    def add_root_graph(self, graph: torch.fx.GraphModule) -> None:
+    def add_root_graph(self, graph: torch.fx.Graph) -> None:
        assert self.root_id is None
        self.root_id = self.add_graph(graph, graph_info_cls=RootGraphInfo)
@@ -314,9 +316,7 @@ def build_rolled_reductions(self) -> None:
            for graph_id, graph_info in enumerate([*self.graphs]):
                assert graph_id == graph_info.graph_id
                roller = ReductionRoller(self, rdim, graph_to_info)
-                new_graph = torch.fx.GraphModule(
-                    {}, roller.process(graph_info.graph.graph)
-                )
+                new_graph = roller.process(graph_info.graph)
                new_graph_id = self.add_graph(
                    new_graph, type(graph_info), **graph_info.kwargs()
                )
@@ -540,7 +540,7 @@ def run_subgraph(*args: object) -> list[object]:
        with self.disable_tracing() as tracer:
            graph = proxy_tensor.make_fx(
                run_subgraph, decomposition_table=select_decomp_table()
-            )(*inputs.get_tensor_args())
+            )(*inputs.get_tensor_args()).graph
        graph_idx = self.device_ir.add_graph(
            graph,
            ForLoopGraphInfo,
@@ -623,7 +623,7 @@ def run_body(*args: object) -> list[object]:
        with self.disable_tracing() as tracer:
            body_graph = proxy_tensor.make_fx(
                run_body, decomposition_table=select_decomp_table()
-            )(*inputs.get_tensor_args())
+            )(*inputs.get_tensor_args()).graph
        assert outputs is not None
        graph_idx = self.device_ir.add_graph(
            body_graph,
@@ -843,8 +843,8 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
    for graph in device_ir.graphs:
        prepare_graph_lowerings(graph.graph)
    for graph in device_ir.graphs:
-        remove_unnecessary_tile_index(graph.graph.graph)
-        remove_unnecessary_masking(graph.graph.graph)
+        remove_unnecessary_tile_index(graph.graph)
+        remove_unnecessary_masking(graph.graph)
    device_ir.build_rolled_reductions()
    return device_ir
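Taken together, the diff makes `DeviceIR` store bare `torch.fx.Graph` objects and only wraps one in a (lazy) module when a printable form is needed. If a callable module were ever required again, a graph could be rehydrated on demand; a hypothetical helper, not part of the PR:

```python
# Pay the fx compile cost only at the point a callable module is
# actually required, rather than at every add_graph() call.
import torch


def as_callable(graph: torch.fx.Graph) -> torch.fx.GraphModule:
    return torch.fx.GraphModule({}, graph)
```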