Commit 643d723

Replace torch.fx.GraphModule with torch.fx.Graph

stack-info: PR: #116, branch: jansel/stack/16

1 parent 3a3684d

File tree

4 files changed: +30 -29 lines changed
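The change is the same across all four files: DeviceIR now stores bare torch.fx.Graph objects and only builds a module wrapper on demand via _LazyGraphModule, which avoids generating Python forward() code for graphs that are never printed or lowered. A minimal, self-contained sketch of that pattern (the example() function below is a made-up stand-in, not code from this repo):

# Sketch only: keep a bare torch.fx.Graph and wrap it lazily when a module
# is actually required. example() is hypothetical, not repo code.
import torch
from torch.fx._lazy_graph_module import _LazyGraphModule


def example(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1


# Trace once, then keep only the Graph (roughly what GraphInfo.graph now holds).
graph: torch.fx.Graph = torch.fx.symbolic_trace(example).graph

# Graph-level passes work directly on the Graph object.
graph.eliminate_dead_code()

# Wrap on demand when an nn.Module-like object is needed, e.g. for printing.
# _LazyGraphModule defers recompile() until the generated code is first used.
print(_LazyGraphModule({}, graph).print_readable(print_output=False))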

helion/_compiler/device_ir.py

Lines changed: 17 additions & 17 deletions
@@ -19,6 +19,7 @@
 import torch
 from torch._dynamo.convert_frame import compile_lock
 from torch._inductor.decomposition import select_decomp_table
+from torch.fx._lazy_graph_module import _LazyGraphModule
 from torch.fx.experimental import proxy_tensor
 from torch.fx.traceback import preserve_node_meta
 from torch.utils import _pytree as pytree
@@ -66,7 +67,7 @@ class _TLS(Protocol):
 tls: _TLS = cast("_TLS", threading.local())


-def _make_fx(fn: Callable[..., object], *args: object) -> torch.fx.GraphModule:
+def _make_fx(fn: Callable[..., object], *args: object) -> torch.fx.Graph:
     """
     We monkey patch get_proxy_slot to support Tensor/SymInt/SymFloat/SymBool in the
     graph without any origin for them. We instead insert _host_tensor(), _get_symnode()
@@ -122,14 +123,13 @@ def _get_proxy_slot(
     current_location().set_fx_location()
     return proxy_tensor.make_fx(fn, decomposition_table=select_decomp_table())(
         *args
-    )
+    ).graph


 @dataclasses.dataclass
 class GraphInfo:
     graph_id: int
-    # TODO(jansel): GraphModule -> Graph to avoid fx compile
-    graph: torch.fx.GraphModule
+    graph: torch.fx.Graph

     @property
     def name(self) -> str:
@@ -140,7 +140,9 @@ def kwargs(self) -> dict[str, object]:
         return {}

     def __str__(self) -> str:
-        output = self.graph.print_readable(print_output=False).strip()
+        output = (
+            _LazyGraphModule({}, self.graph).print_readable(print_output=False).strip()
+        )
         return textwrap.dedent(
             re.sub(
                 r"forward\(self,? ?([^)]*)\)",
@@ -251,7 +253,7 @@ def __init__(self) -> None:
         self.rolled_reductions: list[RolledReductionInfo] = []
         self.grid_block_indices: list[list[int]] = []

-    def get_root(self, config: Config) -> torch.fx.GraphModule:
+    def get_root(self, config: Config) -> torch.fx.Graph:
         """ " If we are using a rolled reduction, return the rolled reduction graph otherwise
         return the root graph."""
         if (root_id := self.root_id) is None:
@@ -276,18 +278,18 @@ def debug_str(self) -> str:

     def add_graph(
         self,
-        graph: torch.fx.GraphModule,
+        graph: torch.fx.Graph,
         graph_info_cls: type[GraphInfo] = GraphInfo,
         **kwargs: object,
     ) -> int:
-        graph.graph.eliminate_dead_code()
+        graph.eliminate_dead_code()
         graph_id = len(self.graphs)
         self.graphs.append(graph_info_cls(graph_id=graph_id, graph=graph, **kwargs))
         return graph_id

     def add_reduction_loop_graph(
         self,
-        graph: torch.fx.GraphModule,
+        graph: torch.fx.Graph,
         block_index: int,
         node_args: list[torch.fx.Node],
     ) -> int:
@@ -298,7 +300,7 @@ def add_reduction_loop_graph(
             node_args=node_args,
         )

-    def add_root_graph(self, graph: torch.fx.GraphModule) -> None:
+    def add_root_graph(self, graph: torch.fx.Graph) -> None:
         assert self.root_id is None
         self.root_id = self.add_graph(graph, graph_info_cls=RootGraphInfo)

@@ -314,9 +316,7 @@ def build_rolled_reductions(self) -> None:
             for graph_id, graph_info in enumerate([*self.graphs]):
                 assert graph_id == graph_info.graph_id
                 roller = ReductionRoller(self, rdim, graph_to_info)
-                new_graph = torch.fx.GraphModule(
-                    {}, roller.process(graph_info.graph.graph)
-                )
+                new_graph = roller.process(graph_info.graph)
                 new_graph_id = self.add_graph(
                     new_graph, type(graph_info), **graph_info.kwargs()
                 )
@@ -540,7 +540,7 @@ def run_subgraph(*args: object) -> list[object]:
         with self.disable_tracing() as tracer:
             graph = proxy_tensor.make_fx(
                 run_subgraph, decomposition_table=select_decomp_table()
-            )(*inputs.get_tensor_args())
+            )(*inputs.get_tensor_args()).graph
             graph_idx = self.device_ir.add_graph(
                 graph,
                 ForLoopGraphInfo,
@@ -623,7 +623,7 @@ def run_body(*args: object) -> list[object]:
         with self.disable_tracing() as tracer:
             body_graph = proxy_tensor.make_fx(
                 run_body, decomposition_table=select_decomp_table()
-            )(*inputs.get_tensor_args())
+            )(*inputs.get_tensor_args()).graph
             assert outputs is not None
             graph_idx = self.device_ir.add_graph(
                 body_graph,
@@ -843,8 +843,8 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
     for graph in device_ir.graphs:
         prepare_graph_lowerings(graph.graph)
     for graph in device_ir.graphs:
-        remove_unnecessary_tile_index(graph.graph.graph)
-        remove_unnecessary_masking(graph.graph.graph)
+        remove_unnecessary_tile_index(graph.graph)
+        remove_unnecessary_masking(graph.graph)
     device_ir.build_rolled_reductions()
     return device_ir
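For context on the make_fx call sites above: make_fx(fn)(*args) still returns a GraphModule; the trailing .graph keeps only the underlying Graph, which is what add_graph() now consumes. A hedged, standalone illustration (the fn() below is made up, not one of the traced subgraphs):

# Hypothetical illustration of the make_fx call sites above; fn() is made up.
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def fn(x: torch.Tensor) -> torch.Tensor:
    return x.sin() * 2


gm = make_fx(fn)(torch.randn(4))  # make_fx still hands back a GraphModule
graph = gm.graph                  # the trailing .graph keeps only the Graph
graph.eliminate_dead_code()       # add_graph() now runs this on the Graph itself
print([node.op for node in graph.nodes])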

helion/_compiler/inductor_lowering.py

Lines changed: 10 additions & 8 deletions
@@ -30,6 +30,7 @@
 from torch._inductor.utils import triton_type
 from torch._inductor.virtualized import OpsValue
 from torch._inductor.virtualized import V
+from torch.fx._lazy_graph_module import _LazyGraphModule
 from torch.fx.experimental import proxy_tensor
 from torch.fx.experimental.sym_node import SymNode
 from torch.fx.interpreter import Interpreter
@@ -65,14 +66,15 @@
 CodegenHandler = Callable[["GraphInterpreter", torch.fx.Node], object]


-def prepare_graph_lowerings(gm: torch.fx.GraphModule) -> None:
+def prepare_graph_lowerings(graph: torch.fx.Graph) -> None:
     with compile_lock:
         graph_lowering = GraphLowering(
-            gm, shape_env=CompileEnvironment.current().shape_env
+            _LazyGraphModule({}, graph),
+            shape_env=CompileEnvironment.current().shape_env,
         )
         # pyre-ignore[19]
         with V.set_graph_handler(graph_lowering):
-            for node in gm.graph.nodes:
+            for node in graph.nodes:
                 assert node.op in {
                     "call_function",
                     "placeholder",
@@ -815,8 +817,8 @@ def _unpack_opsvalue(value: object) -> str:


 class GraphInterpreter(Interpreter):
-    def __init__(self, gm: torch.fx.GraphModule, cg: GenerateAST) -> None:
-        super().__init__(gm, garbage_collect_values=False)
+    def __init__(self, graph: torch.fx.Graph, cg: GenerateAST) -> None:
+        super().__init__(_LazyGraphModule({}, graph), garbage_collect_values=False)
         self.cg = cg

     def run_node(self, n: Node) -> object:
@@ -844,11 +846,11 @@ def run_node(self, n: Node) -> object:


 def codegen_call_with_graph(
-    cg: GenerateAST, gm: torch.fx.GraphModule, args: list[ast.AST]
+    cg: GenerateAST, graph: torch.fx.Graph, args: list[ast.AST]
 ) -> list[object]:
     with compile_lock:
         new_args = []
-        placeholders = gm.graph.find_nodes(op="placeholder")
+        placeholders = graph.find_nodes(op="placeholder")
         for arg, placeholder in zip(args, placeholders, strict=True):
             if all(
                 user.target == torch.ops.aten.sym_size.int for user in placeholder.users
@@ -864,7 +866,7 @@ def codegen_call_with_graph(
                 new_args.append(expr_from_string(copy_name))
             else:
                 new_args.append(cg.lift(arg))
-        return GraphInterpreter(gm, cg).run(*new_args)
+        return GraphInterpreter(graph, cg).run(*new_args)


 class CodegenState(NamedTuple):
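Why GraphInterpreter wraps the graph: torch.fx.Interpreter expects a module, so a throwaway _LazyGraphModule is built around the stored Graph only at interpretation time. A toy sketch of that usage (the traced fn() below is hypothetical, not one of the real device IR graphs):

# Toy sketch of the GraphInterpreter pattern: torch.fx.Interpreter wants a
# module, so a throwaway _LazyGraphModule is built around the stored Graph.
# fn() is hypothetical, not one of the real device IR graphs.
import torch
from torch.fx._lazy_graph_module import _LazyGraphModule


def fn(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1


graph = torch.fx.symbolic_trace(fn).graph

# The lazy wrapper avoids eagerly generating forward() source for every graph.
interp = torch.fx.Interpreter(
    _LazyGraphModule({}, graph), garbage_collect_values=False
)
print(interp.run(torch.ones(3)))  # tensor([3., 3., 3.])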

helion/_compiler/node_masking.py

Lines changed: 2 additions & 2 deletions
@@ -90,7 +90,7 @@ def cached_masked_value(
     """
     device_ir = DeviceIR.current()
     for graph_info in device_ir.graphs:
-        if node.graph is graph_info.graph.graph and isinstance(
+        if node.graph is graph_info.graph and isinstance(
             graph_info, NodeArgsGraphInfo
         ):
             outer_node = graph_info.placeholder_to_outer_arg(node)
@@ -142,7 +142,7 @@ def getitem_masked_value(
     else:
         return None
     assert isinstance(graph_id, int)
-    graph = DeviceIR.current().graphs[graph_id].graph.graph
+    graph = DeviceIR.current().graphs[graph_id].graph
     (output_node,) = graph.find_nodes(op="output")
     (outputs,) = output_node.args
     assert isinstance(outputs, (list, tuple))

helion/_compiler/roll_reduction.py

Lines changed: 1 addition & 2 deletions
@@ -141,9 +141,8 @@ def start_new_graph(self) -> None:
             self.available.add(orig_node)
         graph = self.inner_graph
         graph.output([*outputs.values()])
-        gm = torch.fx.GraphModule({}, graph)
         graph_id = self.device_ir.add_reduction_loop_graph(
-            gm,
+            graph,
             block_index=self.rdim.block_size_idx,
             node_args=self.inner_args,
         )
