9
9
def baseline_sum (x : torch .Tensor ) -> torch .Tensor :
10
10
return x .sum (- 1 )
11
11
12
+ # Naive Reduction: Load the entire reduction dim at once, and reduce in reg.
13
+ @helion .kernel (config = helion .Config (block_sizes = [[1 ]], reduction_loops = [None ], num_warps = 32 , num_stages = 4 , indexing = 'block_ptr' ))
14
+ def longsum (x : torch .Tensor ) -> torch .Tensor :
15
+ m , _ = x .size ()
16
+ out = torch .empty (
17
+ [m ], dtype = x .dtype , device = x .device
18
+ )
19
+
20
+ for tile_m in hl .tile (m ):
21
+ out [tile_m ] = x [tile_m , :].sum (- 1 )
22
+ return out
12
23
24
+
25
+ # Looped reduction
13
26
@helion .kernel (
14
27
config = helion .Config (
15
28
block_sizes = [[1 ]],
16
- reduction_loops = [32768 ], # [None] for non-looped reduction, [tile_size] for looped reduction
29
+ reduction_loops = [32768 ], # [None] for naive reduction, [tile_size] for looped reduction
17
30
num_warps = 16 ,
18
31
num_stages = 5 ,
19
32
indexing = "pointer" ,
20
33
)
21
34
)
22
- def long_sum_reduction (x : torch .Tensor ) -> torch .Tensor :
35
+ def longsum_w_red_loop (x : torch .Tensor ) -> torch .Tensor :
23
36
m , _ = x .size ()
24
37
out = torch .empty (
25
38
[m ], dtype = x .dtype , device = x .device
@@ -32,7 +45,7 @@ def long_sum_reduction(x: torch.Tensor) -> torch.Tensor:
32
45
33
46
# This generates the same code as above, but manually implements looped reduction.
34
47
@helion .kernel (config = helion .Config (block_sizes = [[32768 ], [1 ]], num_warps = 16 , num_stages = 5 , indexing = 'pointer' ))
35
- def long_sum (x : torch .Tensor ) -> torch .Tensor :
48
+ def longsum_manual (x : torch .Tensor ) -> torch .Tensor :
36
49
m , n = x .size ()
37
50
out = torch .empty (
38
51
[m ], dtype = x .dtype , device = x .device
@@ -54,19 +67,24 @@ def check(m: int, n: int) -> None:
54
67
55
68
x = torch .randn ([m , n ], device = "cuda" , dtype = torch .float32 )
56
69
57
- helion_out = long_sum (x )
70
+ helion_out = longsum (x )
58
71
torch .testing .assert_close (helion_out , baseline_sum (x ), rtol = 1e-2 , atol = 1e-1 )
59
- print ("✅ Results Match ✅ Naive Looped Reduction" )
72
+ print ("✅ Results Match ✅ naive reduction" )
73
+
74
+ helion_red_loop_out = longsum_w_red_loop (x )
75
+ torch .testing .assert_close (helion_red_loop_out , baseline_sum (x ), rtol = 1e-2 , atol = 1e-1 )
76
+ print ("✅ Results Match ✅ Reduction Loop" )
60
77
61
- helion_red_out = long_sum_reduction (x )
62
- torch .testing .assert_close (helion_red_out , baseline_sum (x ), rtol = 1e-2 , atol = 1e-1 )
63
- print ("✅ Results Match ✅ Reduction Helion " )
78
+ helion_manual_out = longsum_manual (x )
79
+ torch .testing .assert_close (helion_manual_out , baseline_sum (x ), rtol = 1e-2 , atol = 1e-1 )
80
+ print ("✅ Results Match ✅ Manual Reduction Loop " )
64
81
65
- sec = do_bench (lambda : long_sum (x ))
66
- red_sec = do_bench (lambda : long_sum_reduction (x ))
82
+ sec = do_bench (lambda : longsum (x ))
83
+ loop_sec = do_bench (lambda : longsum_w_red_loop (x ))
84
+ manual_loop_sec = do_bench (lambda : longsum_manual (x ))
67
85
baseline_sec = do_bench (lambda : baseline_sum (x ))
68
86
print (
69
- f"Helion time: { sec :.4f} s, Helion Reduction Time: { red_sec :.4f} , torch time: { baseline_sec :.4f} , speedup: { baseline_sec / sec :.2f} x { baseline_sec / red_sec :.2f} x"
87
+ f"Helion Naive time: { sec :.4f} s, Helion Looped Time: { loop_sec :.4f} , Helion Manual Loop Time: { manual_loop_sec :.4f } torch time: { baseline_sec :.4f} , speedup: { baseline_sec / sec :.2f} x { baseline_sec / loop_sec :.2f } x { baseline_sec / manual_loop_sec :.2f} x"
70
88
)
71
89
72
90
0 commit comments