@@ -8,9 +8,12 @@
 import operator
 import re
 import textwrap
+import threading
 from typing import TYPE_CHECKING
 from typing import Iterator
 from typing import NamedTuple
+from typing import Protocol
+from typing import cast
 from unittest.mock import patch
 
 import torch
@@ -36,6 +39,7 @@
 from .inductor_lowering import CodegenState
 from .inductor_lowering import codegen_call_with_graph
 from .inductor_lowering import prepare_graph_lowerings
+from .node_masking import remove_unnecessary_masking
 from .roll_reduction import ReductionRoller
 from .source_location import current_location
 from .tile_index_proxy import CheckForIndexCalls
@@ -55,6 +59,12 @@
 from collections.abc import Callable
 from collections.abc import Sequence
 
+class _TLS(Protocol):
+    device_irs: list[DeviceIR]
+
+
+tls: _TLS = cast("_TLS", threading.local())
+
 
 def _make_fx(fn: Callable[..., object], *args: object) -> torch.fx.GraphModule:
     """
@@ -151,7 +161,31 @@ def name(self) -> str:
 
 
 @dataclasses.dataclass
-class ForLoopGraphInfo(GraphInfo):
+class NodeArgsGraphInfo(GraphInfo):
+    """Common base class for graphs that have arguments from another graph."""
+
+    node_args: list[torch.fx.Node]
+
+    def placeholder_to_outer_arg(self, node: torch.fx.Node) -> torch.fx.Node:
+        assert node.op == "placeholder"
+        for placeholder, outer_node in zip(
+            node.graph.find_nodes(op="placeholder"),
+            self.node_args,
+            strict=True,
+        ):
+            if placeholder is node:
+                return outer_node
+        raise KeyError("Placeholder not found in node_args")
+
+    def kwargs(self) -> dict[str, object]:
+        # TODO(jansel): do we need to map these to the new graph in the case of a copy?
+        return {
+            "node_args": [*self.node_args],
+        }
+
+
+@dataclasses.dataclass
+class ForLoopGraphInfo(NodeArgsGraphInfo):
     block_indices: list[int]
 
     @property
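`placeholder_to_outer_arg` relies on the invariant that a subgraph's placeholders line up positionally with the outer-graph nodes captured in `node_args`; `zip(..., strict=True)` turns any length mismatch into an immediate error rather than a silent truncation. A rough illustration of that positional pairing with two independently traced graphs (the hypothetical `outer`/`sub` functions are not from this PR):

    import torch

    def outer(a, b):
        return a * b

    def sub(x, y):
        return x + y

    outer_gm = torch.fx.symbolic_trace(outer)
    sub_gm = torch.fx.symbolic_trace(sub)

    # The i-th inner placeholder maps to the i-th recorded outer node.
    node_args = list(outer_gm.graph.find_nodes(op="placeholder"))
    for inner, outer_node in zip(
        sub_gm.graph.find_nodes(op="placeholder"), node_args, strict=True
    ):
        print(inner.name, "->", outer_node.name)  # x -> a, y -> b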
@@ -160,6 +194,7 @@ def name(self) -> str:
 
     def kwargs(self) -> dict[str, object]:
         return {
+            **super().kwargs(),
             "block_indices": [*self.block_indices],
         }
 
@@ -179,14 +214,13 @@ def codegen(self, state: CodegenState) -> list[object]:
         )
 
 
-@dataclasses.dataclass
 class ReductionLoopGraphInfo(ForLoopGraphInfo):
     @property
     def name(self) -> str:
         return f"reduction_loop_{self.graph_id}"
 
 
-class IfGraphInfo(GraphInfo):
+class IfGraphInfo(NodeArgsGraphInfo):
     @property
     def name(self) -> str:
         return f"if_else_graph_{self.graph_id}"
@@ -252,12 +286,16 @@ def add_graph(
         return graph_id
 
     def add_reduction_loop_graph(
-        self, graph: torch.fx.GraphModule, block_index: int
+        self,
+        graph: torch.fx.GraphModule,
+        block_index: int,
+        node_args: list[torch.fx.Node],
     ) -> int:
         return self.add_graph(
             graph,
             graph_info_cls=ReductionLoopGraphInfo,
             block_indices=[block_index],
+            node_args=node_args,
         )
 
     def add_root_graph(self, graph: torch.fx.GraphModule) -> None:
@@ -302,6 +340,19 @@ def build_rolled_reductions(self) -> None:
                 )
             first = False
 
+    def __enter__(self) -> None:
+        try:
+            tls.device_irs.append(self)
+        except AttributeError:
+            tls.device_irs = [self]
+
+    def __exit__(self, *args: object) -> None:
+        tls.device_irs.pop()
+
+    @staticmethod
+    def current() -> DeviceIR:
+        return tls.device_irs[-1]
+
 
 class WalkDeviceAST(NodeVisitor):
     def __init__(self, device_ir: DeviceIR) -> None:
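The three methods added above implement a per-thread stack of active `DeviceIR` instances: `__enter__` pushes (creating the list on first use in a thread), `__exit__` pops, and `current()` returns the innermost entry. Assuming, as `lower_to_device_ir` below does, that `DeviceIR()` takes no arguments, usage looks like this sketch:

    ir = DeviceIR()
    inner = DeviceIR()
    with ir:
        assert DeviceIR.current() is ir
        with inner:  # nesting is safe: the per-thread stack restores on exit
            assert DeviceIR.current() is inner
        assert DeviceIR.current() is ir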
@@ -494,6 +545,7 @@ def run_subgraph(*args: object) -> list[object]:
             graph,
             ForLoopGraphInfo,
             block_indices=[x.block_size_idx for x in iter_vars],
+            node_args=inputs.get_node_args(tracer),
         )
         args = (
             graph_idx,
@@ -576,6 +628,7 @@ def run_body(*args: object) -> list[object]:
         graph_idx = self.device_ir.add_graph(
             body_graph,
             IfGraphInfo,
+            node_args=inputs.get_node_args(tracer),
         )
         args = (
             test_proxy,
@@ -746,6 +799,16 @@ def replace_tensor_args(self, args: Sequence[object]) -> dict[str, object]:
     def get_tensor_args(self) -> list[object]:
         return [self.flat_values[i] for i in self.tensor_indices]
 
+    def get_node_args(
+        self, tracer: proxy_tensor.PythonKeyTracer
+    ) -> list[torch.fx.Node]:
+        proxy_args = args_to_proxies(tracer, self.get_tensor_args())[0]
+        result = []
+        for proxy in proxy_args:
+            assert isinstance(proxy, torch.fx.Proxy)
+            result.append(proxy.node)
+        return result
+
 
 class WalkHostAST(NodeVisitor):
     def __init__(self, device_ir: DeviceIR) -> None:
@@ -771,13 +834,15 @@ def visit_For(self, node: ast.For) -> None:
 
 
 def lower_to_device_ir(func: HostFunction) -> DeviceIR:
-    with func, compile_lock:
-        device_ir = DeviceIR()
+    device_ir = DeviceIR()
+    with func, device_ir, compile_lock:
         visitor = WalkHostAST(device_ir)
         for stmt in func.body:
            visitor.visit(stmt)
         CompileEnvironment.current().errors.raise_if_errors()
         for graph in device_ir.graphs:
             prepare_graph_lowerings(graph.graph)
+        for graph in device_ir.graphs:
+            remove_unnecessary_masking(graph.graph.graph)
         device_ir.build_rolled_reductions()
     return device_ir
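The reordering in this last hunk is what ties the new context-manager protocol together: `DeviceIR()` is now constructed before the `with` statement so the instance itself can be entered alongside `func` and `compile_lock`, which keeps `DeviceIR.current()` available for the duration of the host-AST walk, graph lowering, and the new `remove_unnecessary_masking` pass.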