from __future__ import annotations

import torch

import helion
import helion.language as hl
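
# This example compares three ways to sum along a very long reduction dim:
#   - longsum: load the entire row at once and reduce in registers (naive)
#   - longsum_w_red_loop: let Helion generate a reduction loop via reduction_loops
#   - longsum_manual: the same looped reduction written out by hand

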
def baseline_sum(x: torch.Tensor) -> torch.Tensor:
    return x.sum(-1)


# Naive Reduction: Load the entire reduction dim at once, and reduce in registers.
@helion.kernel(
    config=helion.Config(
        block_sizes=[[1]],
        reduction_loops=[None],
        num_warps=32,
        num_stages=4,
        indexing="block_ptr",
    )
)
def longsum(x: torch.Tensor) -> torch.Tensor:
    m, _ = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)

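    # hl.tile(m) becomes the kernel grid; x[tile_m, :] loads the full
    # reduction dim for each row tile, so the sum happens in registers.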
    for tile_m in hl.tile(m):
        out[tile_m] = x[tile_m, :].sum(-1)
    return out


# Looped Reduction: Helion generates the loop over the reduction dim from the
# reduction_loops config below.
@helion.kernel(
    config=helion.Config(
        block_sizes=[[1]],
        reduction_loops=[
            32768
        ],  # [None] for naive reduction, [tile_size] for looped reduction
        num_warps=16,
        num_stages=5,
        indexing="pointer",
    )
)
def longsum_w_red_loop(x: torch.Tensor) -> torch.Tensor:
    m, _ = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)

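    # Identical loop body to longsum: the reduction_loops=[32768] config is what
    # makes Helion emit a loop over 32768-element chunks of the reduction dim.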
    for tile_m in hl.tile(m):
        out[tile_m] = x[tile_m, :].sum(-1)
    return out


# This generates the same code as above, but manually implements looped reduction.
@helion.kernel(
    config=helion.Config(
        block_sizes=[[32768], [1]], num_warps=16, num_stages=5, indexing="pointer"
    )
)
def longsum_manual(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)

    # Call register_block_size so block_size_n is known outside the reduction loop.
    block_size_n = hl.register_block_size(n)

    for tile_m in hl.tile(m):
        acc = hl.zeros([tile_m, block_size_n], dtype=x.dtype)
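        # Accumulate partial sums one block_size_n-wide chunk at a time; the
        # final reduction over the accumulator happens after the loop.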
        for tile_n in hl.tile(n, block_size=block_size_n):  # Reduction loop
            acc += x[tile_m, tile_n]
        out[tile_m] = acc.sum(-1)
    return out


def check(m: int, n: int) -> None:
    x = torch.randn([m, n], device="cuda", dtype=torch.float32)

    helion_out = longsum(x)
    torch.testing.assert_close(helion_out, baseline_sum(x), rtol=1e-2, atol=1e-1)
    print("✅ Results Match ✅ naive reduction")

    helion_red_loop_out = longsum_w_red_loop(x)
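    # Tolerances are loose: reducing ~130k fp32 values accumulates rounding
    # error, and different reduction orders give slightly different sums.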
    torch.testing.assert_close(
        helion_red_loop_out, baseline_sum(x), rtol=1e-2, atol=1e-1
    )
    print("✅ Results Match ✅ Reduction Loop")

    helion_manual_out = longsum_manual(x)
    torch.testing.assert_close(
        helion_manual_out, baseline_sum(x), rtol=1e-2, atol=1e-1
    )
    print("✅ Results Match ✅ Manual Reduction Loop")


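def bench(m: int, n: int) -> None:
    # A minimal timing sketch, assuming triton's do_bench is available; the
    # original example may benchmark these kernels differently.
    from triton.testing import do_bench

    x = torch.randn([m, n], device="cuda", dtype=torch.float32)
    for name, fn in [
        ("naive reduction", longsum),
        ("reduction loop", longsum_w_red_loop),
        ("manual reduction loop", longsum_manual),
        ("torch baseline", baseline_sum),
    ]:
        ms = do_bench(lambda: fn(x))  # median runtime in milliseconds
        print(f"{name}: {ms:.3f} ms")

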
def main() -> None:
    check(4, 130000)  # seq_len ≈ 128k


if __name__ == "__main__":
    main()