Commit 821784e

Remove 'first_non_grid_index' for hl.grid index

I don't think this was correct, since the grid index shouldn't get expanded.

stack-info: PR: #113, branch: jansel/stack/13

1 parent be5e3bd
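
The effect is visible in the updated test expectations below: an `hl.grid` dimension is driven by a scalar `tl.program_id`, so its index contributes a plain scalar term to the pointer arithmetic instead of a broadcast one. A minimal sketch of the broadcasting argument, using PyTorch tensors to stand in for Triton index vectors (the `indices_*` names and strides mirror the `grid_1d` test below; the snippet itself is illustrative only):

```python
import torch

# Scalar grid index: one program instance per batch element, like
# offset_0 = tl.program_id(0) in the generated kernel.
indices_0 = torch.tensor(3)

# Tile indices: 1D ranges that must be expanded ([:, None] / [None, :])
# so they broadcast against each other inside the 2D tile.
indices_1 = torch.arange(16)
indices_3 = torch.arange(16)

# The scalar broadcasts into the 2D tile as-is; expanding it (the old
# first_non_grid_index behavior) would add a spurious dimension.
offsets = indices_0 * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1
assert offsets.shape == (16, 16)
```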

File tree

3 files changed: +72 −36 lines
README.md
Lines changed: 3 additions & 3 deletions

```diff
@@ -23,8 +23,8 @@ refers to hydrogen-3.
 Helion can be viewed either as *PyTorch with tiles* or as *a higher-level Triton*. Compared to
 Triton, Helion reduces manual coding effort through autotuning. Helion spends more time (approx
 10 min) autotuning as it evaluates hundreds of potential Triton implementations generated
-from a single Helion kernel. This larger search space also makes kernels more performance
-portable between different hardware. Helion automates and autotunes over:
+from a single Helion kernel. This larger search space also makes kernels more performance
+portable between different hardware. Helion automates and autotunes over:

 1. **Tensor Indexing:**

@@ -37,7 +37,7 @@ portable between different hardware. Helion automates and autotunes over:

 3. **Grid Sizes and PID Calculations:**

-   * Automatically determines grid sizes.
+   * Automatically determines grid sizes.
    * Autotunes multiple mappings from Program IDs (PIDs) to data tiles.

 4. **Implicit Search Space Definition:**
```
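
For context on what the README is describing, here is a minimal sketch of a Helion kernel using the `helion.kernel` decorator and `helion.language` tiling API from this repo (illustrative, not an excerpt from the README):

```python
import torch
import helion
import helion.language as hl

@helion.kernel()
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    # Tile shapes, tensor indexing, grid sizes, and PID mappings for this
    # loop are chosen by the autotuner rather than written by hand.
    for tile in hl.tile(out.size()):
        out[tile] = x[tile] + y[tile]
    return out
```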

helion/_compiler/indexing_strategy.py
Lines changed: 11 additions & 27 deletions

```diff
@@ -259,8 +259,8 @@ def create(
         index_values = []
         mask_values = {}
         output_size = SubscriptIndexing.compute_shape(fake_value, index)
-        dtype = CompileEnvironment.current().triton_index_type()
-        first_non_grid_index = 0
+        env = CompileEnvironment.current()
+        dtype = env.triton_index_type()
         for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
@@ -272,18 +272,11 @@ def create(
                 if isinstance(symbol, sympy.Symbol):
                     origin = HostFunction.current().expr_to_origin.get(symbol)
                 if origin and isinstance(origin.origin, BlockSizeOrigin):
-                    if (
-                        CompileEnvironment.current()
-                        .block_sizes[origin.origin.block_size_idx]
-                        .is_grid()
-                    ):
-                        first_non_grid_index = n + 1
-                        expand = tile_strategy.expand_str(output_size, output_idx)
-                    else:
-                        expand = tile_strategy.expand_str(
-                            output_size, output_idx - first_non_grid_index
-                        )
                     index_var = state.codegen.index_var(origin.origin.block_size_idx)
+                    if env.block_sizes[origin.origin.block_size_idx].is_grid():
+                        index_values.append(index_var)
+                        continue
+                    expand = tile_strategy.expand_str(output_size, output_idx)
                     i = len(index_values)
                     index_values.append(f"({index_var}){expand}")
                     if (
@@ -292,18 +285,13 @@ def create(
                         mask_values.setdefault(f"({mask}){expand}")
                     output_idx += 1
                 else:
-                    expand = tile_strategy.expand_str(
-                        output_size, output_idx - first_non_grid_index
-                    )
+                    expand = tile_strategy.expand_str(output_size, output_idx)
                     val = state.device_function.literal_expr(k)
                     index_values.append(f"tl.full([1], {val}, {dtype}){expand}")
             elif isinstance(k, slice) and str(k) == "slice(None, None, None)":
-                expand = tile_strategy.expand_str(
-                    output_size, output_idx - first_non_grid_index
-                )
+                expand = tile_strategy.expand_str(output_size, output_idx)
                 size = fake_value.size(len(index_values))
                 if size != 1:
-                    env = CompileEnvironment.current()
                     rdim = env.allocate_reduction_dimension(size)
                     block_idx = rdim.block_size_idx
                     index_var = state.codegen.index_var(block_idx)
@@ -314,18 +302,14 @@ def create(
                     index_values.append(f"tl.zeros([1], {dtype}){expand}")
                 output_idx += 1
             elif isinstance(k, torch.Tensor) and k.ndim == 1:
-                expand = tile_strategy.expand_str(
-                    output_size, output_idx - first_non_grid_index
-                )
+                expand = tile_strategy.expand_str(output_size, output_idx)
                 ast_index = state.ast_args[1]
                 assert isinstance(ast_index, (list, tuple))
                 assert len(ast_index) == len(index)
                 index_var = state.codegen.lift(ast_index[n]).id
                 index_values.append(f"({index_var}){expand}")
                 if (
-                    block_idx := TileStrategy.get_block_index(
-                        output_size[output_idx - first_non_grid_index]
-                    )
+                    block_idx := TileStrategy.get_block_index(output_size[output_idx])
                 ) is not None:
                     if mask := state.codegen.mask_var(block_idx):
                         mask_values.setdefault(f"({mask}){expand}")
@@ -349,7 +333,7 @@ def create(
                 )
             else:
                 raise exc.InvalidIndexingType(type(k))
-        assert len(output_size) == output_idx - first_non_grid_index
+        assert len(output_size) == output_idx
        assert len(index_values) == fake_value.ndim
        index_expr = []
        for i, idx in enumerate(index_values):
```
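
The shape of the new logic, restated as a toy sketch (hypothetical helper names; the real implementation is `SubscriptIndexing.create` above): grid-backed block indices are emitted verbatim and skipped, and every remaining index is expanded by its position in the output shape directly, so the running `first_non_grid_index` offset is no longer needed:

```python
def build_index_values(index_vars, is_grid, expand_str, output_size):
    # index_vars: generated Triton index expressions, one per input dim
    # is_grid: whether each dim is an hl.grid dimension (scalar pid)
    index_values = []
    output_idx = 0
    for index_var, grid in zip(index_vars, is_grid):
        if grid:
            # A grid index is a scalar program id: use it unexpanded and
            # don't advance output_idx, since it adds no output dimension.
            index_values.append(index_var)
            continue
        # Everything else is broadcast into place within the output shape.
        index_values.append(f"({index_var}){expand_str(output_size, output_idx)}")
        output_idx += 1
    assert output_idx == len(output_size)
    return index_values
```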

test/test_loops.py
Lines changed: 58 additions & 6 deletions

```diff
@@ -517,11 +517,63 @@ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.co
             for offset_3 in range(0, 32, _BLOCK_SIZE_3):
                 indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
                 acc_copy = acc
-                load = tl.load(x + (indices_0[:, None] * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
+                load = tl.load(x + (indices_0 * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
                 load_1 = tl.load(y + (indices_3[:, None] * 4 + indices_2[None, :] * 1), mask_2[None, :], other=0)
                 acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
             v_0 = acc.to(tl.float16)
-            tl.store(out + (indices_0[:, None] * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
+            tl.store(out + (indices_0 * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
+
+def grid_1d(x: torch.Tensor, y: torch.Tensor):
+    b, m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.empty(b, m, n, dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _BLOCK_SIZE_2 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_3 = 16
+    _grid_1d_kernel[8,](x, y, out, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return out
+
+def _grid_1d_make_precompiler(x: torch.Tensor, y: torch.Tensor):
+    b, m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.empty(b, m, n, dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _BLOCK_SIZE_2 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_3 = 16
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_grid_1d_kernel)(x, y, out, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)""",
+        )
+
+        # test again with block_ptr indexing
+        code, result = code_and_output(
+            grid_1d, args, block_sizes=[[16, 16], 16], indexing="block_ptr"
+        )
+        torch.testing.assert_close(result, grid_1d_pytorch(args[0], args[1]))
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    for offset_1 in range(0, 16, _BLOCK_SIZE_1):
+        for offset_2 in range(0, 4, _BLOCK_SIZE_2):
+            acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
+            for offset_3 in range(0, 32, _BLOCK_SIZE_3):
+                acc_copy = acc
+                load = tl.reshape(tl.load(tl.make_block_ptr(x, [8, 16, 32], [512, 32, 1], [offset_0, offset_1, offset_3], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero'), [_BLOCK_SIZE_1, _BLOCK_SIZE_3])
+                load_1 = tl.load(tl.make_block_ptr(y, [32, 4], [4, 1], [offset_3, offset_2], [_BLOCK_SIZE_3, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero')
+                acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+            v_0 = acc.to(tl.float16)
+            tl.store(tl.make_block_ptr(out, [8, 16, 4], [64, 4, 1], [offset_0, offset_1, offset_2], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2], [2, 1, 0]), tl.reshape(v_0, [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2]), boundary_check=[0, 1, 2])

 def grid_1d(x: torch.Tensor, y: torch.Tensor):
     b, m, k = x.size()
@@ -603,11 +655,11 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE
             for offset_4 in range(0, 32, _BLOCK_SIZE_4):
                 indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
                 acc_copy = acc
-                load = tl.load(x + (indices_0[:, None] * 8192 + indices_1[None, :] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
+                load = tl.load(x + (indices_0 * 8192 + indices_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
                 load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
                 acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
             v_0 = acc.to(tl.float16)
-            tl.store(out + (indices_0[:, None] * 4096 + indices_1[None, :] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
+            tl.store(out + (indices_0 * 4096 + indices_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)

 def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor):
     bi, bj, m, k = x.size()
@@ -741,11 +793,11 @@ def _grid_2d_idx_nested_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SI
             for offset_4 in range(0, 32, _BLOCK_SIZE_4):
                 indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
                 acc_copy = acc
-                load = tl.load(x + (indices_0[:, None] * 8192 + indices_1[None, :] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
+                load = tl.load(x + (indices_0 * 8192 + indices_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
                 load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
                 acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
             v_0 = acc.to(tl.float16)
-            tl.store(out + (indices_0[:, None] * 4096 + indices_1[None, :] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
+            tl.store(out + (indices_0 * 4096 + indices_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)

 def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor):
     bi, bj, m, k = x.size()
```
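
The updated expectations drop the `[:, None]` / `[None, :]` expansion on every grid-driven index (`indices_0`, and `indices_1` in the 2D cases) while leaving tile indices expanded. For reference, `grid_1d` is a per-batch matmul; a plausible reconstruction of the `grid_1d_pytorch` reference the test compares against (hypothetical, the real helper is defined elsewhere in test_loops.py):

```python
import torch

def grid_1d_pytorch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x: [b, m, k], y: [k, n] -> out: [b, m, n], one grid step per batch
    b, m, k = x.size()
    k2, n = y.size()
    assert k == k2, f"size mismatch {k} != {k2}"
    out = torch.empty(
        b, m, n, dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
    )
    for i in range(b):
        out[i] = (x[i] @ y).to(out.dtype)
    return out
```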
