Commit 7c7d915

Fix bug where non-tensor variables are not exposed to inner loops
1 parent 2c01ccb commit 7c7d915
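
What this fixes: when a hl.tile(...) loop is nested inside another device loop, the inner loop's body is traced as its own subgraph, and only values that become subgraph arguments (tensors) are passed into it. Plain Python values from the enclosing scope, such as the int sizes returned by x.size() under static_shapes=True, were previously left behind, so the inner loop could not reference them. The kernel added by the new test below hits exactly this case; here it is again with comments marking which loop level reads which non-tensor value (the helion / helion.language-as-hl imports are assumptions based on the test file's use of hl):

import torch
import helion
import helion.language as hl  # assumed import alias; the test file uses hl.tile / hl.zeros

@helion.kernel(static_shapes=True)
def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()   # m, k, n are plain Python ints here, not tensors
    k2, n = y.size()
    assert k == k2, f"size mismatch {k} != {k2}"
    out = torch.empty(
        [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
    )
    for tile_m in hl.tile(m):            # outermost loop: becomes the launch grid
        for tile_n in hl.tile(n):        # nested device loop: must see `n`
            acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
            for tile_k in hl.tile(k):    # doubly nested device loop: must see `k`
                acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
            out[tile_m, tile_n] = acc
    return out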

File tree

helion/_compiler/device_ir.py
test/test_loops.py

2 files changed: +81 -0 lines changed


helion/_compiler/device_ir.py

Lines changed: 7 additions & 0 deletions
@@ -435,6 +435,13 @@ def visit_For(self, node: ast.For) -> None:
         def run_subgraph(*args: object) -> list[object]:
             nonlocal outputs
             subgraph_walker = WalkDeviceAST(self.device_ir)
+            subgraph_walker.scope.update(
+                {
+                    k: v
+                    for k, v in self.scope.items()
+                    if not self.should_become_arg(v)
+                }
+            )
             subgraph_walker.scope.update(inputs.replace_tensor_args(args))
             subgraph_walker._assign(node.target, inner_type.proxy())
             subgraph_walker._body(node.body)
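
Read in isolation, the added lines seed the nested loop's walker with every parent-scope entry that would not become a subgraph argument, before inputs.replace_tensor_args(args) binds the subgraph's tensor placeholders on the following line; the ordering means inherited entries never shadow the tensor arguments. A stand-alone sketch of that idea, with hypothetical names standing in for Helion's internals:

def build_child_scope(parent_scope, should_become_arg, tensor_args):
    # Hypothetical illustration of the fix, not Helion's actual API.
    # 1) Inherit non-argument values (plain ints, floats, configs, ...) from the parent.
    child = {k: v for k, v in parent_scope.items() if not should_become_arg(v)}
    # 2) Then bind the subgraph's own tensor placeholders; they take priority.
    child.update(tensor_args)
    return child

# e.g. build_child_scope({"k": 512, "x": "<tensor>"}, lambda v: v == "<tensor>", {"x": "<placeholder>"})
# returns {"k": 512, "x": "<placeholder>"}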

test/test_loops.py

Lines changed: 74 additions & 0 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import functools
 from pathlib import Path
 import unittest
 
@@ -367,3 +368,76 @@ def fn(x: torch.Tensor, block_size: int) -> torch.Tensor:
         )
         torch.testing.assert_close(result, torch.sin(args[0]))
         self.assertExpectedInline(code, """""")
+
+    def test_three_level_matmul(self):
+        @helion.kernel(static_shapes=True)
+        def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.size()
+            k2, n = y.size()
+            assert k == k2, f"size mismatch {k} != {k2}"
+            out = torch.empty(
+                [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
+            )
+            for tile_m in hl.tile(m):
+                for tile_n in hl.tile(n):
+                    acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                    for tile_k in hl.tile(k):
+                        acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+                    out[tile_m, tile_n] = acc
+            return out
+
+        args = (
+            torch.randn([256, 512], device=DEVICE),
+            torch.randn([512, 128], device=DEVICE),
+        )
+        code, result = code_and_output(matmul, args, block_sizes=[16, 64, 64])
+        torch.testing.assert_close(
+            result, functools.reduce(torch.matmul, args), atol=1e-1, rtol=1e-2
+        )
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+    for offset_1 in range(0, 128, _BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+        for offset_2 in range(0, 512, _BLOCK_SIZE_2):
+            indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+            acc_copy = acc
+            load = tl.load(x + (indices_0[:, None] * 512 + indices_2[None, :] * 1), None)
+            load_1 = tl.load(y + (indices_2[:, None] * 128 + indices_1[None, :] * 1), None)
+            acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+        tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)
+
+def matmul(x: torch.Tensor, y: torch.Tensor):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 64
+    _BLOCK_SIZE_2 = 64
+    _matmul_kernel[triton.cdiv(256, _BLOCK_SIZE_0),](x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
+
+def _matmul_make_precompiler(x: torch.Tensor, y: torch.Tensor):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 64
+    _BLOCK_SIZE_2 = 64
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_matmul_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""",
+        )
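
Two details of the new test: functools.reduce(torch.matmul, args) is simply args[0] @ args[1] when there are two operands (hence the new import functools), and because the kernel is compiled with static_shapes=True, the concrete sizes appear as literals in the expected Triton code: the grid covers m=256 via triton.cdiv, while n=128 and k=512 become the bounds of the two in-kernel loops. A minimal stand-alone check of the reference expression (hypothetical, outside the test harness):

import functools
import torch

x = torch.randn(256, 512)
y = torch.randn(512, 128)
# With exactly two operands, reduce over torch.matmul is the plain product x @ y.
reference = functools.reduce(torch.matmul, (x, y))
torch.testing.assert_close(reference, x @ y)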

0 commit comments
