
Commit cb5ddcd

Add reduction example: Long sum (#92)
1 parent 864df06 commit cb5ddcd

File tree

1 file changed: +107 -0 lines


examples/long_sum.py

Lines changed: 107 additions & 0 deletions
from __future__ import annotations

import torch

import helion
import helion.language as hl


def baseline_sum(x: torch.Tensor) -> torch.Tensor:
    return x.sum(-1)


# Naive reduction: load the entire reduction dim at once and reduce in registers.
@helion.kernel(
    config=helion.Config(
        block_sizes=[[1]],
        reduction_loops=[None],
        num_warps=32,
        num_stages=4,
        indexing="block_ptr",
    )
)
def longsum(x: torch.Tensor) -> torch.Tensor:
    m, _ = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)

    for tile_m in hl.tile(m):
        out[tile_m] = x[tile_m, :].sum(-1)
    return out


# Looped reduction: tile the reduction dim and accumulate across tiles.
@helion.kernel(
    config=helion.Config(
        block_sizes=[[1]],
        reduction_loops=[
            32768
        ],  # [None] for a naive reduction, [tile_size] for a looped reduction
        num_warps=16,
        num_stages=5,
        indexing="pointer",
    )
)
def longsum_w_red_loop(x: torch.Tensor) -> torch.Tensor:
    m, _ = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)

    for tile_m in hl.tile(m):
        out[tile_m] = x[tile_m, :].sum(-1)
    return out


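# For intuition only (an addition, not part of the original example): a plain
# PyTorch sketch of what the looped reduction above computes -- split the
# reduction dim into chunks, reduce each chunk, then combine the partial sums.
def chunked_sum_reference(x: torch.Tensor, chunk_size: int = 32768) -> torch.Tensor:
    partials = [c.sum(-1) for c in x.split(chunk_size, dim=-1)]
    return torch.stack(partials, dim=-1).sum(-1)

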
# This generates the same code as above but implements the looped reduction by hand.
@helion.kernel(
    config=helion.Config(
        block_sizes=[[32768], [1]], num_warps=16, num_stages=5, indexing="pointer"
    )
)
def longsum_manual(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)

    # Call register_block_size so block_size_n is known outside the reduction loop.
    block_size_n = hl.register_block_size(n)

    for tile_m in hl.tile(m):
        acc = hl.zeros([tile_m, block_size_n], dtype=x.dtype)
        for tile_n in hl.tile(n, block_size=block_size_n):  # Reduction loop
            acc += x[tile_m, tile_n]
        out[tile_m] = acc.sum(-1)
    return out


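# Variant sketch (an addition, not part of the original commit): for float16 or
# bfloat16 inputs one would typically accumulate in float32 over a reduction
# this long and cast back on the way out. Assumes .to() is usable inside the
# device loop, as elementwise torch ops generally are in Helion.
@helion.kernel(
    config=helion.Config(
        block_sizes=[[32768], [1]], num_warps=16, num_stages=5, indexing="pointer"
    )
)
def longsum_manual_fp32_acc(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)
    block_size_n = hl.register_block_size(n)

    for tile_m in hl.tile(m):
        # Accumulate in float32 regardless of the input dtype.
        acc = hl.zeros([tile_m, block_size_n], dtype=torch.float32)
        for tile_n in hl.tile(n, block_size=block_size_n):
            acc += x[tile_m, tile_n]
        out[tile_m] = acc.sum(-1).to(x.dtype)
    return out

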
def check(m: int, n: int) -> None:
    from triton.testing import do_bench

    x = torch.randn([m, n], device="cuda", dtype=torch.float32)

    helion_out = longsum(x)
    torch.testing.assert_close(helion_out, baseline_sum(x), rtol=1e-2, atol=1e-1)
    print("✅ Results Match ✅ naive reduction")

    helion_red_loop_out = longsum_w_red_loop(x)
    torch.testing.assert_close(
        helion_red_loop_out, baseline_sum(x), rtol=1e-2, atol=1e-1
    )
    print("✅ Results Match ✅ reduction loop")

    helion_manual_out = longsum_manual(x)
    torch.testing.assert_close(helion_manual_out, baseline_sum(x), rtol=1e-2, atol=1e-1)
    print("✅ Results Match ✅ manual reduction loop")

    # do_bench reports times in milliseconds.
    sec = do_bench(lambda: longsum(x))
    loop_sec = do_bench(lambda: longsum_w_red_loop(x))
    manual_loop_sec = do_bench(lambda: longsum_manual(x))
    baseline_sec = do_bench(lambda: baseline_sum(x))
    print(
        f"Helion naive time: {sec:.4f}ms, Helion looped time: {loop_sec:.4f}ms, "
        f"Helion manual loop time: {manual_loop_sec:.4f}ms, torch time: {baseline_sec:.4f}ms, "
        f"speedups: {baseline_sec / sec:.2f}x, {baseline_sec / loop_sec:.2f}x, "
        f"{baseline_sec / manual_loop_sec:.2f}x"
    )


def main() -> None:
    check(4, 130000)  # reduction length n = 130,000 (~128k)


if __name__ == "__main__":
    main()
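
The configs above are pinned for illustration. As a minimal sketch (an addition, not part of this commit), omitting config= should let Helion's autotuner search the block sizes, reduction loops, warp counts, and indexing modes itself on first call, as in the project's other examples:

# Hypothetical autotuned variant; same kernel body as longsum above.
@helion.kernel()
def longsum_autotuned(x: torch.Tensor) -> torch.Tensor:
    m, _ = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)
    for tile_m in hl.tile(m):
        out[tile_m] = x[tile_m, :].sum(-1)
    return out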
