@@ -517,11 +517,63 @@ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.co
517
517
for offset_3 in range(0, 32, _BLOCK_SIZE_3):
518
518
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
519
519
acc_copy = acc
520
- load = tl.load(x + (indices_0[:, None] * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
520
+ load = tl.load(x + (indices_0 * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
521
521
load_1 = tl.load(y + (indices_3[:, None] * 4 + indices_2[None, :] * 1), mask_2[None, :], other=0)
522
522
acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
523
523
v_0 = acc.to(tl.float16)
524
- tl.store(out + (indices_0[:, None] * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
524
+ tl.store(out + (indices_0 * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
525
+
526
+ def grid_1d(x: torch.Tensor, y: torch.Tensor):
527
+ b, m, k = x.size()
528
+ k2, n = y.size()
529
+ assert k == k2, f'size mismatch {k} != {k2}'
530
+ out = torch.empty(b, m, n, dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
531
+ _BLOCK_SIZE_2 = 16
532
+ _BLOCK_SIZE_1 = 16
533
+ _BLOCK_SIZE_3 = 16
534
+ _grid_1d_kernel[8,](x, y, out, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
535
+ return out
536
+
537
+ def _grid_1d_make_precompiler(x: torch.Tensor, y: torch.Tensor):
538
+ b, m, k = x.size()
539
+ k2, n = y.size()
540
+ assert k == k2, f'size mismatch {k} != {k2}'
541
+ out = torch.empty(b, m, n, dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
542
+ _BLOCK_SIZE_2 = 16
543
+ _BLOCK_SIZE_1 = 16
544
+ _BLOCK_SIZE_3 = 16
545
+ from helion.runtime.precompile_shim import make_precompiler
546
+ return make_precompiler(_grid_1d_kernel)(x, y, out, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)""" ,
547
+ )
548
+
549
+ # test again with block_ptr indexing
550
+ code , result = code_and_output (
551
+ grid_1d , args , block_sizes = [[16 , 16 ], 16 ], indexing = "block_ptr"
552
+ )
553
+ torch .testing .assert_close (result , grid_1d_pytorch (args [0 ], args [1 ]))
554
+ self .assertExpectedInline (
555
+ code ,
556
+ """\
557
+ from __future__ import annotations
558
+
559
+ import torch
560
+ import triton
561
+ import triton.language as tl
562
+
563
+ @triton.jit
564
+ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
565
+ pid_0 = tl.program_id(0)
566
+ offset_0 = pid_0
567
+ for offset_1 in range(0, 16, _BLOCK_SIZE_1):
568
+ for offset_2 in range(0, 4, _BLOCK_SIZE_2):
569
+ acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
570
+ for offset_3 in range(0, 32, _BLOCK_SIZE_3):
571
+ acc_copy = acc
572
+ load = tl.reshape(tl.load(tl.make_block_ptr(x, [8, 16, 32], [512, 32, 1], [offset_0, offset_1, offset_3], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero'), [_BLOCK_SIZE_1, _BLOCK_SIZE_3])
573
+ load_1 = tl.load(tl.make_block_ptr(y, [32, 4], [4, 1], [offset_3, offset_2], [_BLOCK_SIZE_3, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero')
574
+ acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
575
+ v_0 = acc.to(tl.float16)
576
+ tl.store(tl.make_block_ptr(out, [8, 16, 4], [64, 4, 1], [offset_0, offset_1, offset_2], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2], [2, 1, 0]), tl.reshape(v_0, [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2]), boundary_check=[0, 1, 2])
525
577
526
578
def grid_1d(x: torch.Tensor, y: torch.Tensor):
527
579
b, m, k = x.size()
@@ -603,11 +655,11 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE
603
655
for offset_4 in range(0, 32, _BLOCK_SIZE_4):
604
656
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
605
657
acc_copy = acc
606
- load = tl.load(x + (indices_0[:, None] * 8192 + indices_1[None, :] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
658
+ load = tl.load(x + (indices_0 * 8192 + indices_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
607
659
load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
608
660
acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
609
661
v_0 = acc.to(tl.float16)
610
- tl.store(out + (indices_0[:, None] * 4096 + indices_1[None, :] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
662
+ tl.store(out + (indices_0 * 4096 + indices_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
611
663
612
664
def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor):
613
665
bi, bj, m, k = x.size()
@@ -741,11 +793,11 @@ def _grid_2d_idx_nested_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SI
741
793
for offset_4 in range(0, 32, _BLOCK_SIZE_4):
742
794
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
743
795
acc_copy = acc
744
- load = tl.load(x + (indices_0[:, None] * 8192 + indices_1[None, :] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
796
+ load = tl.load(x + (indices_0 * 8192 + indices_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
745
797
load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
746
798
acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
747
799
v_0 = acc.to(tl.float16)
748
- tl.store(out + (indices_0[:, None] * 4096 + indices_1[None, :] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
800
+ tl.store(out + (indices_0 * 4096 + indices_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
749
801
750
802
def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor):
751
803
bi, bj, m, k = x.size()
0 commit comments