Add reduction example: Long sum #92

Merged
merged 4 commits into pytorch-labs:main on Jun 2, 2025
Conversation

joydddd (Contributor) commented May 30, 2025

No description provided.

facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on May 30, 2025
joydddd (Contributor, Author) commented May 30, 2025

✅ Results Match ✅ Naive Helion
✅ Results Match ✅ Reduction Helion
Helion time: 0.0111s, Helion Reduction time: 0.0104s, torch time: 0.0133s, speedup: 1.20x / 1.27x

# Long Sum using Helion's reduction feature
# Config: reduction_loop allows Helion to generate a looped reduction (same as the naive impl above)
# Example Config:
# @helion.kernel(config=helion.Config(block_sizes=[[1]], reduction_loops=[None], num_warps=32, num_stages=4, indexing='block_ptr'))
Contributor commented:

reduction_loops=[None] means a reduction loop won't be used, most likely because the autotuner found it was slower. If the goal of the example is to show how Helion can roll reductions, maybe we should pick a config (like reduction_loops=[1024]) that generates a loop.
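
For concreteness, a minimal sketch of such a pinned config, keeping the other knobs as in the example config above (the function name is assumed and the kernel body is elided; it is the same simple row-sum as in the diff):

import torch
import helion

# Same example config as above, except reduction_loops=[1024] forces a looped
# reduction over n instead of letting the autotuner disable the loop.
# The chunk size 1024 is illustrative, not an autotuned value.
@helion.kernel(
    config=helion.Config(
        block_sizes=[[1]],
        reduction_loops=[1024],
        num_warps=32,
        num_stages=4,
        indexing="block_ptr",
    )
)
def long_sum_reduction(x: torch.Tensor) -> torch.Tensor:
    ...  # same body as in the example: out[tile_m] = x[tile_m, :].sum(-1) per tile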

Comment on lines 14 to 29
# Looped Reduction Long Sum
# Example Config:
# @helion.kernel(config=helion.Config(block_sizes=[[32768], [1]], num_warps=16, num_stages=5, indexing='pointer'))
@helion.kernel()
def long_sum(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)

    # Call register_block_size to know block_size_n outside of the reduction loop.
    block_size_n = hl.register_block_size(n)

    for tile_m in hl.tile(m):
        acc = hl.zeros([tile_m, block_size_n], dtype=x.dtype)
        for tile_n in hl.tile(n, block_size=block_size_n):  # The reduction loop for n that doesn't fit in a tile.
            acc += x[tile_m, tile_n]
        out[tile_m] = acc.sum(-1)
    return out
Contributor commented:

In most cases, we want people to write reductions the second way and not this way.

  1. The second way is less code
  2. The second way has a larger search space since it can choose to do a persistent reduction. (Which it looks like the autotuner picked in this case.)

I worry people are going to be copy-and-pasting from our examples, so I don't want a "bad" example to be that prominent. We can keep it, but we should move it after the "good" example and include a clear warning that it restricts the search space and is equivalent to the first one with reduction_loop != None.
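
For reference, a self-contained sketch of the simpler style being recommended here (imports and the function name are assumptions for illustration; the two-line loop body is the one quoted from the diff just below):

import torch
import helion
import helion.language as hl

@helion.kernel()  # no pinned config: the autotuner can pick a persistent or a looped reduction
def long_sum_reduction(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)
    for tile_m in hl.tile(m):
        # Row-wise sum written directly; no manual accumulator or register_block_size.
        out[tile_m] = x[tile_m, :].sum(-1)
    return out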

Comment on lines +28 to +49
for tile_m in hl.tile(m):
    out[tile_m] = x[tile_m, :].sum(-1)
joydddd (Contributor, Author) commented May 30, 2025:

Helion generated looped reduction:

@triton.jit
def _long_sum_reduction_kernel(x, out, out_stride_0, x_stride_0, x_stride_1, _, _REDUCTION_BLOCK_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    indices_0 = offset_0 + tl.zeros([1], tl.int32)
    sum_1_acc = tl.full([1, _REDUCTION_BLOCK_1], 0, tl.float32)
    for roffset_1 in range(0, _, _REDUCTION_BLOCK_1):
        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
        mask_1 = rindex_1 < _
        load = tl.load(x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1), mask_1[None, :], other=0)
        v_0 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _REDUCTION_BLOCK_1]), load, 0)
        v_1 = sum_1_acc + v_0
        sum_1_acc = v_1
    sum_1 = tl.sum(sum_1_acc, 1)
    tl.store(out + indices_0 * out_stride_0, sum_1, None)
    

There is an additional mask step here, causing a performance drop of 14%.
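
Since the masked tl.load already returns 0 (other=0) for out-of-range lanes, the extra tl.where is effectively a no-op. Conceptually, the loop body above could be simplified to something like this (a sketch of the redundancy, not necessarily the actual change that landed in #109):

# Excerpt of the generated loop body above, with the redundant tl.where removed:
load = tl.load(
    x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1),
    mask_1[None, :],
    other=0,  # masked lanes already contribute 0 to the sum
)
sum_1_acc = sum_1_acc + load  # no tl.where needed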

Contributor commented:

I'll have a PR to fix this in the next day or two.

Contributor commented:

Fixed by #109

Comment on lines 46 to 69
for tile_n in hl.tile(n, block_size=block_size_n):  # Reduction loop
    acc += x[tile_m, tile_n]
joydddd (Contributor, Author) commented:

The manual implementation translates to:

@triton.jit
def _long_sum_kernel(x, out, out_stride_0, x_stride_0, x_stride_1, n, _BLOCK_SIZE_0: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_1 = pid_0
    indices_1 = offset_1 + tl.zeros([1], tl.int32)
    acc = tl.full([1, _BLOCK_SIZE_0], 0.0, tl.float32)
    for offset_0 in range(0, n, _BLOCK_SIZE_0):
        indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
        mask_0 = indices_0 < n
        acc_copy = acc
        load = tl.load(x + (indices_1[:, None] * x_stride_0 + indices_0[None, :] * x_stride_1), mask_0[None, :], other=0)
        acc = acc_copy + load
    sum_1 = tl.sum(acc, 1)
    tl.store(out + indices_1 * out_stride_0, sum_1, None)

Note how masking only happens during tl.load.

joydddd (Contributor, Author) commented May 30, 2025

✅ Results Match ✅ naive reduction
✅ Results Match ✅ Reduction Loop
✅ Results Match ✅ Manual Reduction Loop
Helion Naive time: 0.0107s, Helion Looped time: 0.0131s, Helion Manual Loop time: 0.0111s, torch time: 0.0357s, speedup: 3.35x / 2.73x / 3.22x
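
For context, each reported speedup is the torch time divided by the corresponding Helion time (e.g. 0.0357 / 0.0107 ≈ 3.3x). A rough sketch of how such averages might be measured with CUDA events is below; this harness is assumed for illustration and is not the benchmark script from the example:

import torch

def bench(fn, warmup: int = 10, iters: int = 100) -> float:
    # Returns average seconds per call, timed with CUDA events.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0 / iters  # elapsed_time is in milliseconds

# e.g. speedup = bench(lambda: x.sum(-1)) / bench(lambda: long_sum(x))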

joydddd (Contributor, Author) commented Jun 1, 2025

@drisspg @jansel I don't think I have write permission to this repo. Could you merge this for me?

jansel (Contributor) commented Jun 2, 2025

> I don't think I have write permission to this repo. Could you merge this for me?

You should be able to add yourself through the fb internal page for the project.

jansel merged commit cb5ddcd into pytorch-labs:main on Jun 2, 2025
6 checks passed