
Commit 125c256

Pass to remove unnecessary hl.tile_index calls
stack-info: PR: #115, branch: jansel/stack/15
1 parent b17d302 commit 125c256

2 files changed (+77, -0 lines)


helion/_compiler/device_ir.py

Lines changed: 17 additions & 0 deletions
@@ -843,6 +843,23 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
     for graph in device_ir.graphs:
         prepare_graph_lowerings(graph.graph)
     for graph in device_ir.graphs:
+        remove_unnecessary_tile_index(graph.graph.graph)
         remove_unnecessary_masking(graph.graph.graph)
     device_ir.build_rolled_reductions()
     return device_ir
+
+
+def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
+    """
+    Remove unnecessary tile_index nodes from the graph.
+    Passing a tile directly to index results in less masking.
+    """
+    for node in graph.find_nodes(op="call_function", target=hl.tile_index):
+        for user in [*node.users]:
+            if user.op == "call_function" and user.target in (hl.load, hl.store):
+                new_args = [*user.args]
+                assert isinstance(new_args[1], list)
+                new_args[1] = [(node.args[0] if x is node else x) for x in new_args[1]]
+                user.args = tuple(new_args)
+        if len(node.users) == 0:
+            graph.erase_node(node)
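
For readers less familiar with torch.fx, the pass above is a standard graph rewrite: find every hl.tile_index node, splice its input tile into the index list of any hl.load/hl.store user, then erase the node once it has no remaining users. Below is a minimal, self-contained sketch of that same pattern using hypothetical tile_index/load stand-ins (not the real helion.language ops), so it can be run without Helion:

    import torch
    import torch.fx

    # Hypothetical stand-ins for hl.tile_index / hl.load; names are illustrative only.
    def tile_index(tile):
        return tile

    def load(buf, index):
        return buf[index[0]]

    # Keep the stand-ins as call_function nodes instead of tracing into them.
    torch.fx.wrap("tile_index")
    torch.fx.wrap("load")

    def f(buf, tile):
        return load(buf, [tile_index(tile)])

    def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
        # Same shape as the Helion pass: rewrite the index list of load users
        # in place, then erase tile_index nodes that no longer have users.
        for node in list(graph.nodes):
            if node.op != "call_function" or node.target is not tile_index:
                continue
            for user in list(node.users):
                if user.op == "call_function" and user.target is load:
                    new_args = list(user.args)
                    new_args[1] = [node.args[0] if x is node else x for x in new_args[1]]
                    user.args = tuple(new_args)
            if len(node.users) == 0:
                graph.erase_node(node)

    gm = torch.fx.symbolic_trace(f)
    remove_unnecessary_tile_index(gm.graph)
    gm.recompile()
    print(gm.code)  # the tile_index call is gone; load receives the tile directly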

test/test_masking.py

Lines changed: 60 additions & 0 deletions
@@ -267,3 +267,63 @@ def _fn_make_precompiler(x):
     from helion.runtime.precompile_shim import make_precompiler
     return make_precompiler(_fn_kernel)(out, out.stride(0), m, n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
         )
+
+    def test_tile_index_does_not_mask(self):
+        @helion.kernel(config={"block_sizes": [32, 32], "indexing": "block_ptr"})
+        def fn(x):
+            m, n = x.size()
+            out = torch.empty([m], device=x.device)
+            block_size_n = hl.register_block_size(n)
+            for tile_m in hl.tile(m):
+                acc = hl.zeros([tile_m, block_size_n])
+                for tile_n in hl.tile(0, n, block_size_n):
+                    acc += x[tile_m.index, tile_n.index]
+                out[tile_m.index] = acc.sum(dim=1)
+            return out
+
+        args = (torch.randn([100, 100], device=DEVICE),)
+        code, result = code_and_output(
+            fn,
+            args,
+        )
+        torch.testing.assert_close(result, args[0].sum(dim=1))
+        self.assertNotIn("tl.where", code)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _fn_kernel(x, out, out_size_0, x_size_0, x_size_1, out_stride_0, x_stride_0, x_stride_1, n, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_1 = pid_0 * _BLOCK_SIZE_1
+    acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
+    for offset_0 in range(0, n.to(tl.int32), _BLOCK_SIZE_0):
+        acc_copy = acc
+        load = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_1, offset_0], [_BLOCK_SIZE_1, _BLOCK_SIZE_0], [1, 0]), boundary_check=[0, 1], padding_option='zero')
+        acc = acc_copy + load
+    sum_1 = tl.sum(acc, 1)
+    tl.store(tl.make_block_ptr(out, [out_size_0], [out_stride_0], [offset_1], [_BLOCK_SIZE_1], [0]), sum_1, boundary_check=[0])
+
+def fn(x):
+    m, n = x.size()
+    out = torch.empty([m], device=x.device)
+    block_size_n = 32
+    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_0 = 32
+    _fn_kernel[triton.cdiv(m, _BLOCK_SIZE_1),](x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+def _fn_make_precompiler(x):
+    m, n = x.size()
+    out = torch.empty([m], device=x.device)
+    block_size_n = 32
+    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_0 = 32
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_fn_kernel)(x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
+        )
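
The kernel in the test above still spells its subscripts as tile_m.index / tile_n.index; the new pass rewrites those hl.tile_index nodes away before lowering, which is why the expected Triton output contains no tl.where masking. As the docstring in device_ir.py notes, passing a tile directly to the index avoids the extra masking in the first place. A minimal sketch of the same row-sum kernel written that way (hypothetical row_sums name; assumes direct tile indexing is accepted by the same helion/hl API used in the test):

    import torch
    import helion
    import helion.language as hl

    @helion.kernel(config={"block_sizes": [32, 32], "indexing": "block_ptr"})
    def row_sums(x):
        m, n = x.size()
        out = torch.empty([m], device=x.device)
        block_size_n = hl.register_block_size(n)
        for tile_m in hl.tile(m):
            acc = hl.zeros([tile_m, block_size_n])
            for tile_n in hl.tile(0, n, block_size_n):
                # Index with the tiles themselves; no hl.tile_index nodes are created.
                acc += x[tile_m, tile_n]
            out[tile_m] = acc.sum(dim=1)
        return out

With the pass in place, either spelling should lower to the same unmasked block-pointer loads and stores.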
