Skip to content

Pass to remove unnecessary hl.tile_index calls #115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,23 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
for graph in device_ir.graphs:
prepare_graph_lowerings(graph.graph)
for graph in device_ir.graphs:
remove_unnecessary_tile_index(graph.graph.graph)
remove_unnecessary_masking(graph.graph.graph)
device_ir.build_rolled_reductions()
return device_ir


def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
"""
Remove unnecessary tile_index nodes from the graph.
Passing a tile directly results block_ptrs being supported.
"""
for node in graph.find_nodes(op="call_function", target=hl.tile_index):
for user in [*node.users]:
if user.op == "call_function" and user.target in (hl.load, hl.store):
new_args = [*user.args]
assert isinstance(new_args[1], (list, tuple))
new_args[1] = [(node.args[0] if x is node else x) for x in new_args[1]]
user.args = tuple(new_args)
if len(node.users) == 0:
graph.erase_node(node)
60 changes: 60 additions & 0 deletions test/test_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,63 @@ def _fn_make_precompiler(x):
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_fn_kernel)(out, out.stride(0), m, n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
)

def test_tile_index_does_not_mask(self):
@helion.kernel(config={"block_sizes": [32, 32], "indexing": "block_ptr"})
def fn(x):
m, n = x.size()
out = torch.empty([m], device=x.device)
block_size_n = hl.register_block_size(n)
for tile_m in hl.tile(m):
acc = hl.zeros([tile_m, block_size_n])
for tile_n in hl.tile(0, n, block_size_n):
acc += x[tile_m.index, tile_n.index]
out[tile_m.index] = acc.sum(dim=1)
return out

args = (torch.randn([100, 100], device=DEVICE),)
code, result = code_and_output(
fn,
args,
)
torch.testing.assert_close(result, args[0].sum(dim=1))
self.assertNotIn("tl.where", code)
self.assertExpectedInline(
code,
"""\
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _fn_kernel(x, out, out_size_0, x_size_0, x_size_1, out_stride_0, x_stride_0, x_stride_1, n, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_1 = pid_0 * _BLOCK_SIZE_1
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
for offset_0 in range(0, n.to(tl.int32), _BLOCK_SIZE_0):
acc_copy = acc
load = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_1, offset_0], [_BLOCK_SIZE_1, _BLOCK_SIZE_0], [1, 0]), boundary_check=[0, 1], padding_option='zero')
acc = acc_copy + load
sum_1 = tl.sum(acc, 1)
tl.store(tl.make_block_ptr(out, [out_size_0], [out_stride_0], [offset_1], [_BLOCK_SIZE_1], [0]), sum_1, boundary_check=[0])

def fn(x):
m, n = x.size()
out = torch.empty([m], device=x.device)
block_size_n = 32
_BLOCK_SIZE_1 = 32
_BLOCK_SIZE_0 = 32
_fn_kernel[triton.cdiv(m, _BLOCK_SIZE_1),](x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

def _fn_make_precompiler(x):
m, n = x.size()
out = torch.empty([m], device=x.device)
block_size_n = 32
_BLOCK_SIZE_1 = 32
_BLOCK_SIZE_0 = 32
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_fn_kernel)(x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
)
Loading