@@ -267,3 +267,63 @@ def _fn_make_precompiler(x):
     from helion.runtime.precompile_shim import make_precompiler
     return make_precompiler(_fn_kernel)(out, out.stride(0), m, n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
         )
+
+    def test_tile_index_does_not_mask(self):
+        @helion.kernel(config={"block_sizes": [32, 32], "indexing": "block_ptr"})
+        def fn(x):
+            m, n = x.size()
+            out = torch.empty([m], device=x.device)
+            block_size_n = hl.register_block_size(n)
+            for tile_m in hl.tile(m):
+                acc = hl.zeros([tile_m, block_size_n])
+                for tile_n in hl.tile(0, n, block_size_n):
+                    acc += x[tile_m.index, tile_n.index]
+                out[tile_m.index] = acc.sum(dim=1)
+            return out
+
+        args = (torch.randn([100, 100], device=DEVICE),)
+        code, result = code_and_output(
+            fn,
+            args,
+        )
+        torch.testing.assert_close(result, args[0].sum(dim=1))
+        self.assertNotIn("tl.where", code)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _fn_kernel(x, out, out_size_0, x_size_0, x_size_1, out_stride_0, x_stride_0, x_stride_1, n, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_1 = pid_0 * _BLOCK_SIZE_1
+    acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
+    for offset_0 in range(0, n.to(tl.int32), _BLOCK_SIZE_0):
+        acc_copy = acc
+        load = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_1, offset_0], [_BLOCK_SIZE_1, _BLOCK_SIZE_0], [1, 0]), boundary_check=[0, 1], padding_option='zero')
+        acc = acc_copy + load
+    sum_1 = tl.sum(acc, 1)
+    tl.store(tl.make_block_ptr(out, [out_size_0], [out_stride_0], [offset_1], [_BLOCK_SIZE_1], [0]), sum_1, boundary_check=[0])
+
+def fn(x):
+    m, n = x.size()
+    out = torch.empty([m], device=x.device)
+    block_size_n = 32
+    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_0 = 32
+    _fn_kernel[triton.cdiv(m, _BLOCK_SIZE_1),](x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+def _fn_make_precompiler(x):
+    m, n = x.size()
+    out = torch.empty([m], device=x.device)
+    block_size_n = 32
+    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_0 = 32
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_fn_kernel)(x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
+        )