Skip to content

Fix bug where non-tensor variables are not exposed to inner loops #58

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,13 @@ def visit_For(self, node: ast.For) -> None:
def run_subgraph(*args: object) -> list[object]:
nonlocal outputs
subgraph_walker = WalkDeviceAST(self.device_ir)
subgraph_walker.scope.update(
{
k: v
for k, v in self.scope.items()
if not self.should_become_arg(v)
}
)
subgraph_walker.scope.update(inputs.replace_tensor_args(args))
subgraph_walker._assign(node.target, inner_type.proxy())
subgraph_walker._body(node.body)
Expand Down Expand Up @@ -519,6 +526,9 @@ def _create_if_subgraph(self, test_proxy: object, body: list[ast.stmt]) -> None:
def run_body(*args: object) -> list[object]:
nonlocal outputs
subgraph_walker = WalkDeviceAST(self.device_ir)
subgraph_walker.scope.update(
{k: v for k, v in self.scope.items() if not self.should_become_arg(v)}
)
subgraph_walker.scope.update(inputs.replace_tensor_args(args))
subgraph_walker._body(body)
outputs = LiftTensorArgs(
Expand Down
74 changes: 74 additions & 0 deletions test/test_loops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
from pathlib import Path
import unittest

Expand Down Expand Up @@ -367,3 +368,76 @@ def fn(x: torch.Tensor, block_size: int) -> torch.Tensor:
)
torch.testing.assert_close(result, torch.sin(args[0]))
self.assertExpectedInline(code, """""")

def test_three_level_matmul(self):
@helion.kernel(static_shapes=True)
def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
m, k = x.size()
k2, n = y.size()
assert k == k2, f"size mismatch {k} != {k2}"
out = torch.empty(
[m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
)
for tile_m in hl.tile(m):
for tile_n in hl.tile(n):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
for tile_k in hl.tile(k):
acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
out[tile_m, tile_n] = acc
return out

args = (
torch.randn([256, 512], device=DEVICE),
torch.randn([512, 128], device=DEVICE),
)
code, result = code_and_output(matmul, args, block_sizes=[16, 64, 64])
torch.testing.assert_close(
result, functools.reduce(torch.matmul, args), atol=1e-1, rtol=1e-2
)
self.assertExpectedInline(
code,
"""\
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
for offset_1 in range(0, 128, _BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
for offset_2 in range(0, 512, _BLOCK_SIZE_2):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
acc_copy = acc
load = tl.load(x + (indices_0[:, None] * 512 + indices_2[None, :] * 1), None)
load_1 = tl.load(y + (indices_2[:, None] * 128 + indices_1[None, :] * 1), None)
acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)

def matmul(x: torch.Tensor, y: torch.Tensor):
m, k = x.size()
k2, n = y.size()
assert k == k2, f'size mismatch {k} != {k2}'
out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
_BLOCK_SIZE_0 = 16
_BLOCK_SIZE_1 = 64
_BLOCK_SIZE_2 = 64
_matmul_kernel[triton.cdiv(256, _BLOCK_SIZE_0),](x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
return out

def _matmul_make_precompiler(x: torch.Tensor, y: torch.Tensor):
m, k = x.size()
k2, n = y.size()
assert k == k2, f'size mismatch {k} != {k2}'
out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
_BLOCK_SIZE_0 = 16
_BLOCK_SIZE_1 = 64
_BLOCK_SIZE_2 = 64
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_matmul_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""",
)
Loading