
Commit 1e6c617

Add naive reduction + looped reduction examples
1 parent 0be70e5 commit 1e6c617

File tree

1 file changed: +29 -11 lines changed

examples/long_sum.py

Lines changed: 29 additions & 11 deletions
@@ -9,17 +9,30 @@
 def baseline_sum(x: torch.Tensor) -> torch.Tensor:
     return x.sum(-1)
 
+# Naive Reduction: Load the entire reduction dim at once, and reduce in reg.
+@helion.kernel(config=helion.Config(block_sizes=[[1]], reduction_loops=[None], num_warps=32, num_stages=4, indexing='block_ptr'))
+def longsum(x: torch.Tensor) -> torch.Tensor:
+    m, _ = x.size()
+    out = torch.empty(
+        [m], dtype=x.dtype, device=x.device
+    )
+
+    for tile_m in hl.tile(m):
+        out[tile_m] = x[tile_m, :].sum(-1)
+    return out
 
+
+# Looped reduction
 @helion.kernel(
     config=helion.Config(
         block_sizes=[[1]],
-        reduction_loops=[32768], # [None] for non-looped reduction, [tile_size] for looped reduction
+        reduction_loops=[32768], # [None] for naive reduction, [tile_size] for looped reduction
         num_warps=16,
         num_stages=5,
         indexing="pointer",
     )
 )
-def long_sum_reduction(x: torch.Tensor) -> torch.Tensor:
+def longsum_w_red_loop(x: torch.Tensor) -> torch.Tensor:
     m, _ = x.size()
     out = torch.empty(
         [m], dtype=x.dtype, device=x.device
@@ -32,7 +45,7 @@ def long_sum_reduction(x: torch.Tensor) -> torch.Tensor:
 
 # This generates the same code as above, but manually implements looped reduction.
 @helion.kernel(config=helion.Config(block_sizes=[[32768], [1]], num_warps=16, num_stages=5, indexing='pointer'))
-def long_sum(x: torch.Tensor) -> torch.Tensor:
+def longsum_manual(x: torch.Tensor) -> torch.Tensor:
     m, n = x.size()
     out = torch.empty(
         [m], dtype=x.dtype, device=x.device
@@ -54,19 +67,24 @@ def check(m: int, n: int) -> None:
 
     x = torch.randn([m, n], device="cuda", dtype=torch.float32)
 
-    helion_out = long_sum(x)
+    helion_out = longsum(x)
     torch.testing.assert_close(helion_out, baseline_sum(x), rtol=1e-2, atol=1e-1)
-    print("✅ Results Match ✅ Naive Looped Reduction")
+    print("✅ Results Match ✅ naive reduction")
+
+    helion_red_loop_out = longsum_w_red_loop(x)
+    torch.testing.assert_close(helion_red_loop_out, baseline_sum(x), rtol=1e-2, atol=1e-1)
+    print("✅ Results Match ✅ Reduction Loop")
 
-    helion_red_out = long_sum_reduction(x)
-    torch.testing.assert_close(helion_red_out, baseline_sum(x), rtol=1e-2, atol=1e-1)
-    print("✅ Results Match ✅ Reduction Helion")
+    helion_manual_out = longsum_manual(x)
+    torch.testing.assert_close(helion_manual_out, baseline_sum(x), rtol=1e-2, atol=1e-1)
+    print("✅ Results Match ✅ Manual Reduction Loop")
 
-    sec = do_bench(lambda: long_sum(x))
-    red_sec = do_bench(lambda: long_sum_reduction(x))
+    sec = do_bench(lambda: longsum(x))
+    loop_sec = do_bench(lambda: longsum_w_red_loop(x))
+    manual_loop_sec = do_bench(lambda: longsum_manual(x))
     baseline_sec = do_bench(lambda: baseline_sum(x))
     print(
-        f"Helion time: {sec:.4f}s, Helion Reduction Time: {red_sec:.4f}, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x {baseline_sec / red_sec:.2f}x"
+        f"Helion Naive time: {sec:.4f}s, Helion Looped Time: {loop_sec:.4f}, Helion Manual Loop Time: {manual_loop_sec:.4f} torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x {baseline_sec / loop_sec:.2f}x {baseline_sec / manual_loop_sec:.2f}x"
     )
 
 
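Note on the two truncated kernels: the hunks above show only the decorators and signatures of longsum_w_red_loop and longsum_manual; their loop bodies sit outside the changed lines. The body of longsum_w_red_loop should match the naive kernel's body, since the chunked reduction is driven entirely by reduction_loops=[32768] in the config. For the manual variant, the sketch below shows one plausible hand-written equivalent. It is not code from this commit: the hl.zeros accumulator, the nested hl.tile loop, the import aliases, and the name longsum_manual_sketch are assumptions made for illustration.

import torch
import helion
import helion.language as hl


# Hypothetical sketch only -- the real longsum_manual body is not shown in this diff.
@helion.kernel(config=helion.Config(block_sizes=[[32768], [1]], num_warps=16, num_stages=5, indexing='pointer'))
def longsum_manual_sketch(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)
    for tile_m in hl.tile(m):
        # Accumulate partial row sums over chunks of the reduction dim,
        # doing by hand what reduction_loops=[32768] generates automatically.
        acc = hl.zeros([tile_m], dtype=x.dtype)  # assumed Helion helper
        for tile_n in hl.tile(n):
            acc += x[tile_m, tile_n].sum(-1)
        out[tile_m] = acc
    return out

Either way, the point of the third kernel's comment stands: the config-driven looped reduction and the hand-written loop are meant to lower to essentially the same generated code, so the benchmark at the bottom of check() compares three routes to the same row sum.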