Commit d648076

Improve mask optimization to cover control flow and inductor ops

stack-info: PR: #111, branch: jansel/stack/11
1 parent: 2c55636

9 files changed: +367 −37 lines changed


helion/_compiler/device_ir.py

Lines changed: 71 additions & 6 deletions
@@ -8,9 +8,12 @@
 import operator
 import re
 import textwrap
+import threading
 from typing import TYPE_CHECKING
 from typing import Iterator
 from typing import NamedTuple
+from typing import Protocol
+from typing import cast
 from unittest.mock import patch
 
 import torch
@@ -36,6 +39,7 @@
 from .inductor_lowering import CodegenState
 from .inductor_lowering import codegen_call_with_graph
 from .inductor_lowering import prepare_graph_lowerings
+from .node_masking import remove_unnecessary_masking
 from .roll_reduction import ReductionRoller
 from .source_location import current_location
 from .tile_index_proxy import CheckForIndexCalls
@@ -55,6 +59,12 @@
     from collections.abc import Callable
     from collections.abc import Sequence
 
+
+class _TLS(Protocol):
+    device_irs: list[DeviceIR]
+
+
+tls: _TLS = cast("_TLS", threading.local())
 
 def _make_fx(fn: Callable[..., object], *args: object) -> torch.fx.GraphModule:
     """
@@ -151,7 +161,31 @@ def name(self) -> str:
 
 
 @dataclasses.dataclass
-class ForLoopGraphInfo(GraphInfo):
+class NodeArgsGraphInfo(GraphInfo):
+    """Common base class for graphs that have arguments from another graph."""
+
+    node_args: list[torch.fx.Node]
+
+    def placeholder_to_outer_arg(self, node: torch.fx.Node) -> torch.fx.Node:
+        assert node.op == "placeholder"
+        for placeholder, outer_node in zip(
+            node.graph.find_nodes(op="placeholder"),
+            self.node_args,
+            strict=True,
+        ):
+            if placeholder is node:
+                return outer_node
+        raise KeyError("Placeholder not found in node_args")
+
+    def kwargs(self) -> dict[str, object]:
+        # TODO(jansel): do we need to map these to the new graph in the case of a copy?
+        return {
+            "node_args": [*self.node_args],
+        }
+
+
+@dataclasses.dataclass
+class ForLoopGraphInfo(NodeArgsGraphInfo):
     block_indices: list[int]
 
     @property
@@ -160,6 +194,7 @@ def name(self) -> str:
 
     def kwargs(self) -> dict[str, object]:
         return {
+            **super().kwargs(),
             "block_indices": [*self.block_indices],
         }
 
@@ -179,14 +214,13 @@ def codegen(self, state: CodegenState) -> list[object]:
         )
 
 
-@dataclasses.dataclass
 class ReductionLoopGraphInfo(ForLoopGraphInfo):
     @property
     def name(self) -> str:
         return f"reduction_loop_{self.graph_id}"
 
 
-class IfGraphInfo(GraphInfo):
+class IfGraphInfo(NodeArgsGraphInfo):
     @property
     def name(self) -> str:
         return f"if_else_graph_{self.graph_id}"
@@ -252,12 +286,16 @@ def add_graph(
         return graph_id
 
     def add_reduction_loop_graph(
-        self, graph: torch.fx.GraphModule, block_index: int
+        self,
+        graph: torch.fx.GraphModule,
+        block_index: int,
+        node_args: list[torch.fx.Node],
     ) -> int:
         return self.add_graph(
             graph,
             graph_info_cls=ReductionLoopGraphInfo,
             block_indices=[block_index],
+            node_args=node_args,
         )
 
     def add_root_graph(self, graph: torch.fx.GraphModule) -> None:
@@ -302,6 +340,19 @@ def build_rolled_reductions(self) -> None:
             )
             first = False
 
+    def __enter__(self) -> None:
+        try:
+            tls.device_irs.append(self)
+        except AttributeError:
+            tls.device_irs = [self]
+
+    def __exit__(self, *args: object) -> None:
+        tls.device_irs.pop()
+
+    @staticmethod
+    def current() -> DeviceIR:
+        return tls.device_irs[-1]
+
 
 class WalkDeviceAST(NodeVisitor):
     def __init__(self, device_ir: DeviceIR) -> None:
@@ -494,6 +545,7 @@ def run_subgraph(*args: object) -> list[object]:
             graph,
             ForLoopGraphInfo,
             block_indices=[x.block_size_idx for x in iter_vars],
+            node_args=inputs.get_node_args(tracer),
         )
         args = (
             graph_idx,
@@ -576,6 +628,7 @@ def run_body(*args: object) -> list[object]:
         graph_idx = self.device_ir.add_graph(
             body_graph,
             IfGraphInfo,
+            node_args=inputs.get_node_args(tracer),
         )
         args = (
             test_proxy,
@@ -746,6 +799,16 @@ def replace_tensor_args(self, args: Sequence[object]) -> dict[str, object]:
     def get_tensor_args(self) -> list[object]:
         return [self.flat_values[i] for i in self.tensor_indices]
 
+    def get_node_args(
+        self, tracer: proxy_tensor.PythonKeyTracer
+    ) -> list[torch.fx.Node]:
+        proxy_args = args_to_proxies(tracer, self.get_tensor_args())[0]
+        result = []
+        for proxy in proxy_args:
+            assert isinstance(proxy, torch.fx.Proxy)
+            result.append(proxy.node)
+        return result
+
 
 class WalkHostAST(NodeVisitor):
     def __init__(self, device_ir: DeviceIR) -> None:
@@ -771,13 +834,15 @@ def visit_For(self, node: ast.For) -> None:
 
 
 def lower_to_device_ir(func: HostFunction) -> DeviceIR:
-    with func, compile_lock:
-        device_ir = DeviceIR()
+    device_ir = DeviceIR()
+    with func, device_ir, compile_lock:
         visitor = WalkHostAST(device_ir)
         for stmt in func.body:
             visitor.visit(stmt)
         CompileEnvironment.current().errors.raise_if_errors()
         for graph in device_ir.graphs:
             prepare_graph_lowerings(graph.graph)
+        for graph in device_ir.graphs:
+            remove_unnecessary_masking(graph.graph.graph)
         device_ir.build_rolled_reductions()
     return device_ir
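
Note on the thread-local additions above: the `_TLS` protocol, the module-level `tls` object, and the `__enter__`/`__exit__`/`current()` trio on `DeviceIR` form the usual thread-local "current context" stack, so code deep inside tracing can find the active `DeviceIR` without threading it through every call. Below is a minimal standalone sketch of the same pattern; the names (`Context`, `tls.stack`) are illustrative only and are not Helion's API.

import threading
from typing import Protocol, cast


class _TLS(Protocol):
    stack: list["Context"]


# threading.local() starts with no attributes on each new thread,
# which is why __enter__ below needs the AttributeError fallback.
tls: _TLS = cast("_TLS", threading.local())


class Context:
    def __enter__(self) -> "Context":
        try:
            tls.stack.append(self)
        except AttributeError:
            tls.stack = [self]
        return self

    def __exit__(self, *args: object) -> None:
        tls.stack.pop()

    @staticmethod
    def current() -> "Context":
        # the innermost (most recently entered) context wins
        return tls.stack[-1]


with Context() as outer:
    assert Context.current() is outer
    with Context() as inner:
        assert Context.current() is inner
    assert Context.current() is outer

Using a stack rather than a single slot keeps the pattern correct under nesting, which matters here because `lower_to_device_ir` now enters the `DeviceIR` itself as a context manager around the whole lowering pass.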

helion/_compiler/inductor_lowering.py

Lines changed: 16 additions & 6 deletions
@@ -46,6 +46,8 @@
 from .compile_environment import CompileEnvironment
 from .node_masking import apply_masking
 from .node_masking import cached_masked_value
+from .node_masking import getitem_masked_value
+from .node_masking import inductor_masked_value
 from .node_masking import mask_node_inputs
 from .tile_strategy import TileStrategy
 
@@ -372,9 +374,7 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         return expr_from_string(output_name)
 
     def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
-        """Get the masked value for this node."""
-        # TODO(jansel): use valueranges to determine masked value
-        return None
+        return inductor_masked_value(self, node)
 
 
 @dataclasses.dataclass
@@ -465,10 +465,20 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
             node.meta["val"],
         )
 
+    def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
+        # reduction types that preserve zeroness
+        if self.reduction_type in {"sum", "prod", "min", "max"}:
+            value = inductor_masked_value(self, node)
+            if value == 0:
+                return value
+        return None
+
 
-@dataclasses.dataclass
 class APIFuncLowering(Lowering):
-    api_func: APIFunc
+    def __init__(self, api_func: object) -> None:
+        super().__init__()
+        assert is_api_func(api_func)
+        self.api_func: APIFunc = api_func
 
     def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         assert not node.kwargs
@@ -580,7 +590,7 @@ def codegen_sym_size(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     return val
 
 
-@register_lowering(getitem)
+@register_lowering(getitem, masked_value_fn=getitem_masked_value)
 def codegen_getitem(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     assert not node.kwargs, "getitem kwargs not supported"
     lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
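
Note on the new `get_masked_value` for reductions above: it relies on the fact that if every masked-out lane of the reduction's input already holds 0, then an output lane that reduces only padding is reducing an all-zero block, and sum, prod, min, and max of an all-zero block are again 0. The output therefore already carries the mask value 0 and no extra masking is needed. A small illustration in plain PyTorch of just that arithmetic fact (not Helion's codegen path):

import torch

# A block whose lanes are entirely "out of bounds" and masked to 0.
masked_block = torch.zeros(8)

# Zeroness is preserved by these reduction types: reducing the all-zero
# block yields 0 again, i.e. the value the mask already implies.
for reduce in (torch.sum, torch.prod, torch.amin, torch.amax):
    assert reduce(masked_block).item() == 0.0

Reduction types without this property, or inputs whose masked value is not 0, fall through to `return None`, which presumably leaves the conservative path (no assumption about the masked value) in place.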
