
Commit 288e2af

Optimization pass to remove unneeded masking (#109)
1 parent 3ba8a0b commit 288e2af

19 files changed (+641 / -184 lines)

helion/_compiler/compile_environment.py
Lines changed: 5 additions & 0 deletions

@@ -333,6 +333,11 @@ def size_hint(self) -> int:
         assert isinstance(size, (int, torch.SymInt))
         return CompileEnvironment.current().size_hint(size)
 
+    def size_matches(self, numel: sympy.Expr | None) -> bool:
+        if numel is None or not isinstance(self.size, (int, torch.SymInt)):
+            return False
+        return numel == self.numel
+
     def mark_alternate_size(self, size: torch.SymInt | int | None) -> None:
         """If a block size is used with a different size, we need to clear the hint to enable masking."""
         if isinstance(self.size, AutoSize):
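To make the intent of the new helper concrete, here is a minimal, hypothetical sketch; BlockSizeInfoSketch and its attributes are stand-ins invented for illustration, not the repository's actual class:

from __future__ import annotations

import sympy
import torch


class BlockSizeInfoSketch:
    """Hypothetical stand-in for the block-size record that gained size_matches()."""

    def __init__(self, size: int | torch.SymInt, numel: sympy.Expr) -> None:
        self.size = size    # the hinted block/tensor size
        self.numel = numel  # symbolic extent of the tiled dimension

    def size_matches(self, numel: sympy.Expr | None) -> bool:
        # Mirrors the added method: only a concrete or symbolic int size can match.
        if numel is None or not isinstance(self.size, (int, torch.SymInt)):
            return False
        return numel == self.numel


info = BlockSizeInfoSketch(size=128, numel=sympy.Integer(128))
print(info.size_matches(sympy.Integer(128)))  # True  -> loop bound equals the dimension
print(info.size_matches(sympy.Integer(100)))  # False -> sizes differ, masking stays required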

helion/_compiler/indexing_strategy.py
Lines changed: 22 additions & 3 deletions

@@ -13,6 +13,7 @@
 from .ast_extension import expr_from_string
 from .compile_environment import CompileEnvironment
 from .host_function import HostFunction
+from .tile_strategy import DeviceLoopState
 from .tile_strategy import TileStrategy
 from .variable_origin import BlockSizeOrigin

@@ -203,7 +204,7 @@ def compute_shape(
     tensor: torch.Tensor, index: list[object]
 ) -> list[int | torch.SymInt]:
     assert isinstance(tensor, torch.Tensor)
-    assert isinstance(index, (list, tuple))
+    assert isinstance(index, (list, tuple)), index
     input_size = collections.deque(tensor.size())
     output_size = []
     for k in index:

@@ -455,8 +456,9 @@ def is_supported(
         index: list[object],
         extra_mask: ast.AST | None,
     ) -> bool:
+        # TODO(jansel): TensorDescriptor has some extra restrictions that are not captured here.
         if extra_mask is not None:
-            # TODO(jansel): block_ptr with extra_mask
+            # TODO(jansel): support block_ptr with extra_mask
            return False
         for k in index:
             if isinstance(k, torch.SymInt):

@@ -465,10 +467,27 @@ def is_supported(
                 if isinstance(symbol, sympy.Symbol):
                     origin = HostFunction.current().expr_to_origin.get(symbol)
                     if origin and isinstance(origin.origin, BlockSizeOrigin):
+                        block_index = origin.origin.block_size_idx
                         try:
-                            state.codegen.offset_var(origin.origin.block_size_idx)
+                            state.codegen.offset_var(block_index)
                         except NotImplementedError:
                             return False
+                        loop_state = state.codegen.active_device_loops[block_index][-1]
+                        if isinstance(loop_state, DeviceLoopState):
+                            """
+                            Check for a corner case where the loop size does not match the tensor size.
+                            In this case, the block masking will be incorrect. So we check if the
+                            masking is needed and bail if it is.
+                            """
+                            end = loop_state.end_bounds[block_index]
+                            if (
+                                not CompileEnvironment.current()
+                                .block_sizes[block_index]
+                                .size_matches(end)
+                            ):
+                                assert state.fx_node is not None
+                                if "masked_value" in state.fx_node.meta:
+                                    return False
             if isinstance(k, torch.Tensor):
                 # indirect loads don't work with block_ptr
                 return False
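Roughly, the corner case this new check guards against can be pictured as follows; all numbers and helper names below are invented for illustration, and the interpretation (block pointers bound-check against the tensor extent rather than the loop end) is an assumption based on the comment in the diff:

# Hypothetical scenario: a device loop runs over 100 elements in blocks of 64,
# but the tensor being indexed has 120 elements along that dimension. A block
# pointer would bound-check against the tensor extent (120), while correctness
# here needs masking against the loop end (100), so block_ptr must be rejected
# whenever the node actually relies on a specific masked value.
loop_end = 100              # stands in for loop_state.end_bounds[block_index]
tensor_extent = 120         # stands in for what size_matches() compares against
needs_masked_value = True   # stands in for `"masked_value" in state.fx_node.meta`


def block_ptr_is_safe() -> bool:
    if loop_end != tensor_extent and needs_masked_value:
        return False  # fall back to ordinary pointer + mask indexing
    return True


print(block_ptr_is_safe())  # False in this made-up case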

helion/_compiler/inductor_lowering.py
Lines changed: 108 additions & 13 deletions

@@ -13,6 +13,7 @@
 import torch
 from torch._dynamo.convert_frame import compile_lock
 from torch._inductor import config as inductor_config
+from torch._inductor import ir
 from torch._inductor.codegen.simd import SIMDKernelFeatures
 from torch._inductor.codegen.simd import constant_repr
 from torch._inductor.codegen.triton import TritonKernel

@@ -43,6 +44,9 @@
 from .ast_extension import expr_from_string
 from .ast_extension import statement_from_string
 from .compile_environment import CompileEnvironment
+from .node_masking import apply_masking
+from .node_masking import cached_masked_value
+from .node_masking import mask_node_inputs
 from .tile_strategy import TileStrategy
 
 if TYPE_CHECKING:

@@ -185,7 +189,9 @@ def convert_arg(arg: Node) -> TensorBox:
                 )
             ),
         )
-        new_node.meta["lowering"] = lowering_cls(buffer, used_input_names)
+        new_node.meta["lowering"] = lowering = lowering_cls(buffer, used_input_names)
+        if isinstance(lowering, ReductionLowering):
+            lowering.add_input_mask(new_node)
         nodes.append(new_node)
         extra_input_names.append(buffer.get_name())
 
@@ -269,6 +275,10 @@ class Lowering:
     def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         raise NotImplementedError
 
+    def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
+        """Get the masked value for this node."""
+        return None
+
 
 @dataclasses.dataclass
 class InductorLowering(Lowering):

@@ -361,6 +371,11 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         output_name = _unpack_opsvalue(self.buffer.data.inner_fn(indices))
         return expr_from_string(output_name)
 
+    def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
+        """Get the masked value for this node."""
+        # TODO(jansel): use valueranges to determine masked value
+        return None
+
 
 @dataclasses.dataclass
 class ReductionLowering(InductorLowering):

@@ -383,6 +398,25 @@ def __init__(
         assert block_index is not None
         self.block_index: int = block_index
 
+    @property
+    def reduction_type(self) -> str:
+        reduction = self.buffer.data
+        assert isinstance(reduction, Reduction)
+        return reduction.reduction_type
+
+    def add_input_mask(self, node: torch.fx.Node) -> None:
+        """Modify the node to apply masking for the reduction if needed."""
+        reduction_type = self.reduction_type
+        input_dtype = None
+        for inp in node.all_input_nodes:
+            if isinstance(inp.meta["val"], torch.Tensor):
+                input_dtype = inp.meta["val"].dtype
+                break
+        assert input_dtype is not None
+        default = ir.Reduction.default_accumulator(reduction_type, input_dtype)
+        assert isinstance(default, (float, int, bool))
+        mask_node_inputs(node, default)
+
     def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         reduction = self.buffer.data
         assert isinstance(reduction, Reduction)

@@ -463,6 +497,11 @@ def normalize_args_kwargs(
         node.args = (*bound.arguments.values(),)
         node.kwargs = {}
 
+    def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
+        if self.api_func._get_masked_value is not None:
+            return self.api_func._get_masked_value(node)
+        return None
+
 
 @dataclasses.dataclass
 class SympyExprLowering(Lowering):

@@ -471,31 +510,61 @@ class SympyExprLowering(Lowering):
     def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         return expr_from_string(ctx.cg.device_function.user_sympy_expr(self.expr))
 
+    def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
+        if isinstance(self.expr, sympy.Integer):
+            return int(self.expr)
+        if isinstance(self.expr, sympy.Float):
+            return float(self.expr)
+        return None
+
 
 @dataclasses.dataclass
 class LambdaLowering(Lowering):
     fn: Callable[..., object]
+    masked_value_fn: Callable[[torch.fx.Node], float | bool | None] | None = None
 
     def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         return self.fn(ctx, node)
 
+    def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
+        if self.masked_value_fn is not None:
+            return self.masked_value_fn(node)
+        return None
+
+
+def passthrough_masked_value(node: torch.fx.Node) -> float | bool | None:
+    for input_node in node.all_input_nodes:
+        if isinstance(input_node.meta["val"], torch.Tensor):
+            return cached_masked_value(input_node)
+    return None
+
 
 aten_lowering_dispatch: dict[object, Callable[[torch.fx.Node], Lowering]] = {}
 
 
-def default_make_lowering(handler: CodegenHandler, node: torch.fx.Node) -> Lowering:
-    return LambdaLowering(handler)
+def default_make_lowering(
+    handler: CodegenHandler,
+    node: torch.fx.Node,
+    masked_value_fn: Callable[[torch.fx.Node], float | bool | None] | None = None,
+) -> Lowering:
+    return LambdaLowering(handler, masked_value_fn=masked_value_fn)
 
 
 def register_lowering(
     fn: object,
     make_lowering: Callable[
         [CodegenHandler, torch.fx.Node], Lowering
     ] = default_make_lowering,
+    masked_value_fn: Callable[[torch.fx.Node], float | bool | None] | None = None,
 ) -> Callable[[CodegenHandler], CodegenHandler]:
     def decorator(handler: CodegenHandler) -> CodegenHandler:
         assert fn not in aten_lowering_dispatch, f"Lowering for {fn} already registered"
-        aten_lowering_dispatch[fn] = lambda node: make_lowering(handler, node)
+        # pyre-ignore[28]
+        aten_lowering_dispatch[fn] = lambda node: make_lowering(
+            handler,
+            node,
+            masked_value_fn=masked_value_fn,
+        )
         return handler
 
     return decorator

@@ -521,7 +590,12 @@ def codegen_getitem(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
 
 
 # pyre-fixme[56]
-@register_lowering(torch.ops.aten.full.default)
+@register_lowering(
+    torch.ops.aten.full.default,
+    masked_value_fn=lambda n: (
+        n.args[1] if isinstance(n.args[1], (int, float, bool)) else None
+    ),
+)
 def codegen_full(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     env = CompileEnvironment.current()
     size, fill_value = map_arg(node.args, lambda n: n.meta["val"])

@@ -539,7 +613,9 @@ def codegen_full(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
 
 
 # pyre-fixme[56]
-@register_lowering(torch.ops.aten.unsqueeze.default)
+@register_lowering(
+    torch.ops.aten.unsqueeze.default, masked_value_fn=passthrough_masked_value
+)
 def codegen_unsqueeze(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     assert not node.kwargs, "getitem kwargs not supported"
     tensor, dim = map_arg(node.args, lambda arg: ctx.env[arg])

@@ -557,10 +633,14 @@ def codegen_unsqueeze(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     )
 
 
-@register_lowering(torch.ops.aten.squeeze.dim)
-@register_lowering(torch.ops.aten.view.default)
+@register_lowering(torch.ops.aten.squeeze.dim, masked_value_fn=passthrough_masked_value)
+@register_lowering(
+    torch.ops.aten.view.default, masked_value_fn=passthrough_masked_value
+)
 # pyre-fixme[56]
-@register_lowering(torch.ops.aten.reshape.default)
+@register_lowering(
+    torch.ops.aten.reshape.default, masked_value_fn=passthrough_masked_value
+)
 def codegen_view(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     assert not node.kwargs, "view kwargs not supported"
     tensor = map_arg(node.args[0], lambda arg: ctx.env[arg])

@@ -572,7 +652,9 @@ def codegen_view(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
 
 
 # pyre-fixme[56]
-@register_lowering(torch.ops.aten.permute.default)
+@register_lowering(
+    torch.ops.aten.permute.default, masked_value_fn=passthrough_masked_value
+)
 def codegen_permute(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     assert not node.kwargs, "getitem kwargs not supported"
     tensor, dims = map_arg(node.args, lambda arg: ctx.env[arg])

@@ -586,7 +668,9 @@ def codegen_permute(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
 
 
 # pyre-fixme[56]
-@register_lowering(torch.ops.aten.expand.default)
+@register_lowering(
+    torch.ops.aten.expand.default, masked_value_fn=passthrough_masked_value
+)
 def codegen_expand(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     assert not node.kwargs, "getitem kwargs not supported"
     tensor, _ = map_arg(node.args, lambda arg: ctx.env[arg])

@@ -606,7 +690,11 @@ def codegen_expand(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     )
 
 
-def apply_dot_requirements(handler: CodegenHandler, node: torch.fx.Node) -> Lowering:
+def apply_dot_requirements(
+    handler: CodegenHandler,
+    node: torch.fx.Node,
+    masked_value_fn: Callable[[torch.fx.Node], float | bool | None] | None = None,
+) -> Lowering:
     """Apply min_dot_size requirements to the config_spec"""
     assert not node.kwargs, "dot kwargs not supported"
     assert len(node.args) in (2, 3)

@@ -625,7 +713,14 @@ def apply_dot_requirements(handler: CodegenHandler, node: torch.fx.Node) -> Lowering:
         block_idx = TileStrategy.get_block_index(shape)
         if block_idx is not None:
             env.block_sizes[block_idx].update_min_block(min_size, allow_flattened=True)
-    return LambdaLowering(handler)
+    # inputs to the dot operation must be zero-masked
+    *maybe_acc, lnode, rnode = node.args
+    assert isinstance(lnode, torch.fx.Node)
+    assert isinstance(rnode, torch.fx.Node)
+    lnode = apply_masking(lnode, base_node=node, other=0)
+    rnode = apply_masking(rnode, base_node=node, other=0)
+    node.args = (*maybe_acc, lnode, rnode)
+    return LambdaLowering(handler, masked_value_fn=masked_value_fn)
 
 
 @register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements)
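To show how the masked_value_fn plumbing behaves end to end, here is a small self-contained sketch on a hand-built fx graph. The helper names and the direct meta reads below are simplifications invented for illustration; the real code routes through Lowering.get_masked_value and cached_masked_value rather than reading node.meta directly:

from __future__ import annotations

import torch
import torch.fx


def full_masked_value(node: torch.fx.Node) -> float | bool | None:
    # Mirrors the lambda registered for aten.full.default: out-of-bounds lanes
    # of a full() hold the fill value itself.
    fill = node.args[1]
    return fill if isinstance(fill, (int, float, bool)) else None


def passthrough(node: torch.fx.Node) -> float | bool | None:
    # Mirrors passthrough_masked_value: a shape-only op keeps the masked value
    # of its first tensor input.
    for inp in node.all_input_nodes:
        if isinstance(inp.meta.get("val"), torch.Tensor):
            return inp.meta.get("masked_value")
    return None


# Build a tiny standalone fx graph (full -> view) and annotate metadata by hand.
graph = torch.fx.Graph()
full_node = graph.call_function(torch.ops.aten.full.default, ((4, 4), 3.0))
full_node.meta["val"] = torch.empty(4, 4)
full_node.meta["masked_value"] = full_masked_value(full_node)

view_node = graph.call_function(torch.ops.aten.view.default, (full_node, (16,)))
print(passthrough(view_node))  # 3.0 -- the fill value survives the reshape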

helion/_compiler/node_masking.py
Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+import functools
+
+import torch.fx
+from torch.fx.experimental import proxy_tensor
+
+from helion.language._tracing_ops import _mask_to
+
+
+def mask_node_inputs(
+    node: torch.fx.Node,
+    other: float | bool = 0,
+) -> None:
+    """Inplace update the node's args and kwargs to apply masking if needed."""
+    apply = functools.partial(apply_masking, other=other, base_node=node)
+    node.args = torch.fx.map_arg(node.args, apply)
+    node.kwargs = torch.fx.map_arg(node.kwargs, apply)
+
+
+def apply_masking(
+    node: torch.fx.Node,
+    *,
+    base_node: torch.fx.Node,
+    other: float | bool = 0,
+) -> torch.fx.Node:
+    """Analyze the node and apply masking if needed."""
+    current_mask = cached_masked_value(node)
+    if current_mask == other:
+        return node  # already masked, no need to change it
+    for user in node.users:
+        if user.op == "call_function" and user.target == _mask_to:
+            if user.args[1] == other:
+                assert user.args[0] is node
+                return user  # reuse existing mask_to node
+    from helion._compiler.inductor_lowering import APIFuncLowering
+
+    # If we reach here, we need to create a new mask_to node
+    with node.graph.inserting_before(base_node):
+        new_node = node.graph.call_function(_mask_to, (node, other), {})
+        new_node.meta.update(base_node.meta)
+        with proxy_tensor.disable_proxy_modes_tracing():
+            new_node.meta["val"] = node.meta["val"].clone()
+        # pyre-ignore[6]
+        new_node.meta["lowering"] = APIFuncLowering(_mask_to)
+        new_node.meta["masked_value"] = other
+    return new_node
+
+
+def cached_masked_value(
+    node: torch.fx.Node,
+) -> float | bool | None:
+    """Determine the current masked value for the node."""
+    if "masked_value" in node.meta:
+        return node.meta["masked_value"]
+    if node.op != "call_function":
+        return None
+    node.meta["masked_value"] = result = node.meta["lowering"].get_masked_value(node)
+    return result
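As a final illustration, here is a simplified, self-contained version of the caching and insertion logic behind apply_masking. It uses a stub in place of helion.language._tracing_ops._mask_to and skips the lowering/val metadata bookkeeping; names ending in _sketch or _stub are invented for this example:

from __future__ import annotations

import torch
import torch.fx


def mask_to_stub(tensor: torch.Tensor, other: float) -> torch.Tensor:
    # Stand-in for helion's _mask_to tracing op, which fills the masked-out
    # lanes of `tensor` with `other`.
    return tensor


def apply_masking_sketch(
    node: torch.fx.Node, *, base_node: torch.fx.Node, other: float | bool = 0
) -> torch.fx.Node:
    # 1) The node already holds `other` in its masked lanes: nothing to do.
    if node.meta.get("masked_value") == other:
        return node
    # 2) An existing mask_to user with the same fill value: reuse it.
    for user in node.users:
        if user.op == "call_function" and user.target is mask_to_stub and user.args[1] == other:
            return user
    # 3) Otherwise insert a fresh mask_to right before the consumer.
    with node.graph.inserting_before(base_node):
        new_node = node.graph.call_function(mask_to_stub, (node, other))
        new_node.meta["masked_value"] = other
    return new_node


graph = torch.fx.Graph()
x = graph.placeholder("x")
dot = graph.call_function(torch.matmul, (x, x))
first = apply_masking_sketch(x, base_node=dot, other=0)
second = apply_masking_sketch(x, base_node=dot, other=0)
assert first is second  # the second request reuses the mask_to node inserted by the first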
