Skip to content

Commit 84c8e31

Browse files
committed
Lint
1 parent 1e6c617 commit 84c8e31

File tree

1 file changed

+27
-16
lines changed

1 file changed

+27
-16
lines changed

examples/long_sum.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,28 @@
11
from __future__ import annotations
2-
from unittest import result
32

43
import torch
54

65
import helion
76
import helion.language as hl
87

8+
99
def baseline_sum(x: torch.Tensor) -> torch.Tensor:
1010
return x.sum(-1)
1111

12+
1213
# Naive Reduction: Load the entire reduction dim at once, and reduce in reg.
13-
@helion.kernel(config=helion.Config(block_sizes=[[1]], reduction_loops=[None], num_warps=32, num_stages=4, indexing='block_ptr'))
14+
@helion.kernel(
15+
config=helion.Config(
16+
block_sizes=[[1]],
17+
reduction_loops=[None],
18+
num_warps=32,
19+
num_stages=4,
20+
indexing="block_ptr",
21+
)
22+
)
1423
def longsum(x: torch.Tensor) -> torch.Tensor:
1524
m, _ = x.size()
16-
out = torch.empty(
17-
[m], dtype=x.dtype, device=x.device
18-
)
25+
out = torch.empty([m], dtype=x.dtype, device=x.device)
1926

2027
for tile_m in hl.tile(m):
2128
out[tile_m] = x[tile_m, :].sum(-1)
@@ -26,37 +33,39 @@ def longsum(x: torch.Tensor) -> torch.Tensor:
2633
@helion.kernel(
2734
config=helion.Config(
2835
block_sizes=[[1]],
29-
reduction_loops=[32768], # [None] for naive reduction, [tile_size] for looped reduction
36+
reduction_loops=[
37+
32768
38+
], # [None] for naive reduction, [tile_size] for looped reduction
3039
num_warps=16,
3140
num_stages=5,
3241
indexing="pointer",
3342
)
3443
)
3544
def longsum_w_red_loop(x: torch.Tensor) -> torch.Tensor:
3645
m, _ = x.size()
37-
out = torch.empty(
38-
[m], dtype=x.dtype, device=x.device
39-
)
46+
out = torch.empty([m], dtype=x.dtype, device=x.device)
4047

4148
for tile_m in hl.tile(m):
4249
out[tile_m] = x[tile_m, :].sum(-1)
4350
return out
4451

4552

4653
# This generates the same code as above, but manually implements looped reduction.
47-
@helion.kernel(config=helion.Config(block_sizes=[[32768], [1]], num_warps=16, num_stages=5, indexing='pointer'))
54+
@helion.kernel(
55+
config=helion.Config(
56+
block_sizes=[[32768], [1]], num_warps=16, num_stages=5, indexing="pointer"
57+
)
58+
)
4859
def longsum_manual(x: torch.Tensor) -> torch.Tensor:
4960
m, n = x.size()
50-
out = torch.empty(
51-
[m], dtype=x.dtype, device=x.device
52-
)
61+
out = torch.empty([m], dtype=x.dtype, device=x.device)
5362

5463
# Call register_block_size to know block_size_n outside of the reduction loop.
5564
block_size_n = hl.register_block_size(n)
5665

5766
for tile_m in hl.tile(m):
5867
acc = hl.zeros([tile_m, block_size_n], dtype=x.dtype)
59-
for tile_n in hl.tile(n, block_size=block_size_n): # Reduction loop
68+
for tile_n in hl.tile(n, block_size=block_size_n): # Reduction loop
6069
acc += x[tile_m, tile_n]
6170
out[tile_m] = acc.sum(-1)
6271
return out
@@ -72,7 +81,9 @@ def check(m: int, n: int) -> None:
7281
print("✅ Results Match ✅ naive reduction")
7382

7483
helion_red_loop_out = longsum_w_red_loop(x)
75-
torch.testing.assert_close(helion_red_loop_out, baseline_sum(x), rtol=1e-2, atol=1e-1)
84+
torch.testing.assert_close(
85+
helion_red_loop_out, baseline_sum(x), rtol=1e-2, atol=1e-1
86+
)
7687
print("✅ Results Match ✅ Reduction Loop")
7788

7889
helion_manual_out = longsum_manual(x)
@@ -89,7 +100,7 @@ def check(m: int, n: int) -> None:
89100

90101

91102
def main() -> None:
92-
check(4, 130000) # seq_len = 128k
103+
check(4, 130000) # seq_len = 128k
93104

94105

95106
if __name__ == "__main__":

0 commit comments

Comments
 (0)