Add reduction example: Long sum #92
Conversation
✅ Results Match ✅ Naive Helion
examples/long_sum.py
Outdated
```python
# Long Sum using Helion's reduction feature
# Config: reduction_loop allows Helion to generate a looped reduction (same as the naive impl above)
# Example Config:
# @helion.kernel(config=helion.Config(block_sizes=[[1]], reduction_loops=[None], num_warps=32, num_stages=4, indexing='block_ptr'))
```
`reduction_loops=[None]` means a reduction loop won't be used, most likely because the autotuner found it was slower. If the goal of the example is to show how Helion can roll reductions, maybe we should pick a config (like `reduction_loops=[1024]`) that generates a loop.
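The reviewer's suggestion might look like the following sketch, reusing the parameter values from the example config above but with a concrete `reduction_loops` value (the value 1024 is the reviewer's illustration, not a tuned choice):

```python
# Hypothetical config forcing a looped reduction, per the review suggestion.
# All other parameters are copied from the example config shown above.
@helion.kernel(
    config=helion.Config(
        block_sizes=[[1]],
        reduction_loops=[1024],  # reduce n in 1024-wide chunks instead of one shot
        num_warps=32,
        num_stages=4,
        indexing="block_ptr",
    )
)
def long_sum(x: torch.Tensor) -> torch.Tensor:
    ...
```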
examples/long_sum.py
Outdated
```python
# Looped Reduction Long Sum
# Example Config:
# @helion.kernel(config=helion.Config(block_sizes=[[32768], [1]], num_warps=16, num_stages=5, indexing='pointer'))
@helion.kernel()
def long_sum(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)

    # Call register_block_size to know block_size_n outside of the reduction loop.
    block_size_n = hl.register_block_size(n)

    for tile_m in hl.tile(m):
        acc = hl.zeros([tile_m, block_size_n], dtype=x.dtype)
        for tile_n in hl.tile(n, block_size=block_size_n):  # The reduction loop for n that doesn't fit in a tile.
            acc += x[tile_m, tile_n]
        out[tile_m] = acc.sum(-1)
    return out
```
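The tiling strategy in the kernel above can be sketched in plain NumPy (the helper name and the block size of 8 are illustrative, not from the example):

```python
import numpy as np

def long_sum_reference(x: np.ndarray, block_size_n: int = 8) -> np.ndarray:
    """NumPy sketch of the tiled reduction: accumulate each row's sum in
    block_size_n-wide chunks, then reduce the accumulator once at the end."""
    m, n = x.shape
    out = np.empty(m, dtype=x.dtype)
    for i in range(m):  # stands in for the hl.tile(m) grid loop
        acc = np.zeros(block_size_n, dtype=x.dtype)
        for start in range(0, n, block_size_n):  # the reduction loop over n
            chunk = x[i, start:start + block_size_n]
            # A final partial chunk only touches its own lanes; the rest of
            # acc stays zero (the masked load in the generated Triton code).
            acc[:chunk.size] += chunk
        out[i] = acc.sum()
    return out
```

The two-level structure (wide accumulator, single final reduction) is what lets Triton keep the inner loop free of cross-lane operations.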
In most cases, we want people to write reductions the second way and not this way.
- The second way is less code
- The second way has a larger search space, since it can choose to do a persistent reduction. (Which it looks like the autotuner picked in this case.)

I worry people are going to be copy-and-pasting from our examples, so I don't want a "bad" example to be that prominent. We can keep it, but we should move it after the "good" example and include a clear warning that this restricts the search space and is equivalent to the first one with `reduction_loop != None`.
```python
for tile_m in hl.tile(m):
    out[tile_m] = x[tile_m, :].sum(-1)
```
Helion generated looped reduction:

```python
@triton.jit
def _long_sum_reduction_kernel(x, out, out_stride_0, x_stride_0, x_stride_1, _, _REDUCTION_BLOCK_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    indices_0 = offset_0 + tl.zeros([1], tl.int32)
    sum_1_acc = tl.full([1, _REDUCTION_BLOCK_1], 0, tl.float32)
    for roffset_1 in range(0, _, _REDUCTION_BLOCK_1):
        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
        mask_1 = rindex_1 < _
        load = tl.load(x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1), mask_1[None, :], other=0)
        v_0 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _REDUCTION_BLOCK_1]), load, 0)
        v_1 = sum_1_acc + v_0
        sum_1_acc = v_1
    sum_1 = tl.sum(sum_1_acc, 1)
    tl.store(out + indices_0 * out_stride_0, sum_1, None)
```
There is an additional mask step here (the `tl.where` after the load), causing a performance drop of 14%.
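The extra mask is redundant because the masked `tl.load` already substitutes 0 (`other=0`) for out-of-bounds lanes, so a following `tl.where` over the same mask recomputes identical values. A NumPy analogue (the `masked_load` helper simulates Triton's masked-load semantics; names are illustrative):

```python
import numpy as np

def masked_load(row: np.ndarray, start: int, block: int):
    """Simulate tl.load with other=0: out-of-bounds lanes read as 0."""
    idx = start + np.arange(block)
    mask = idx < row.size
    # Clamp indices so the gather is in bounds, then zero the masked lanes.
    loaded = np.where(mask, row[np.minimum(idx, row.size - 1)], 0.0)
    return loaded, mask

row = np.arange(10, dtype=np.float64)
loaded, mask = masked_load(row, 8, 8)   # last, partial block of the row
remasked = np.where(mask, loaded, 0.0)  # the redundant v_0 step
# The second mask changes nothing: zero-filled lanes are already zero.
assert np.array_equal(loaded, remasked)
```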
I'll have a PR to fix this in the next day or two.
Fixed by #109
examples/long_sum.py
Outdated
```python
for tile_n in hl.tile(n, block_size=block_size_n):  # Reduction loop
    acc += x[tile_m, tile_n]
```
The manual implementation translates to:

```python
@triton.jit
def _long_sum_kernel(x, out, out_stride_0, x_stride_0, x_stride_1, n, _BLOCK_SIZE_0: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_1 = pid_0
    indices_1 = offset_1 + tl.zeros([1], tl.int32)
    acc = tl.full([1, _BLOCK_SIZE_0], 0.0, tl.float32)
    for offset_0 in range(0, n, _BLOCK_SIZE_0):
        indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
        mask_0 = indices_0 < n
        acc_copy = acc
        load = tl.load(x + (indices_1[:, None] * x_stride_0 + indices_0[None, :] * x_stride_1), mask_0[None, :], other=0)
        acc = acc_copy + load
    sum_1 = tl.sum(acc, 1)
    tl.store(out + indices_1 * out_stride_0, sum_1, None)
```
Note how masking only happens during `tl.load`.
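In plain NumPy terms (an illustrative sketch, not the generated code): since out-of-bounds lanes load as 0, they contribute nothing to the accumulator, so no further masking is needed before the final sum.

```python
import numpy as np

def row_sum_masked_loads(row: np.ndarray, block: int = 8) -> float:
    """Accumulate a row in fixed-width blocks, masking only at load time
    (out-of-bounds lanes read as 0), as in the kernel above."""
    acc = np.zeros(block)
    for start in range(0, row.size, block):
        idx = start + np.arange(block)
        mask = idx < row.size
        # tl.load(..., mask, other=0): clamp the gather, zero masked lanes.
        load = np.where(mask, row[np.minimum(idx, row.size - 1)], 0.0)
        acc = acc + load  # no extra tl.where needed on the accumulator
    return float(acc.sum())

row = np.arange(10, dtype=np.float64)
assert row_sum_masked_loads(row) == row.sum()
```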
✅ Results Match ✅ naive reduction

You should be able to add yourself through the fb internal page for the project.