
Add hl.register_block_size and explicit tile sizes #30


Merged · 1 commit · May 12, 2025

Conversation

jansel (Contributor) commented May 12, 2025

This change makes possible the two-pass softmax that @drisspg was trying to implement, by adding new syntax for pre-registering block sizes:

@helion.kernel(config={"block_sizes": [1, 128]})
def softmax_two_pass(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty_like(x)
    # Register block sizes up front so they can be referenced before their loops
    block_size_m = hl.register_block_size(m)
    block_size_n = hl.register_block_size(n)
    for tile_m in hl.tile(m, block_size=block_size_m):
        # Pass 1: running row max (mi) and rescaled running denominator (di)
        mi = hl.full([tile_m, 1], float("-inf"), dtype=torch.float32)
        di = hl.zeros([tile_m, block_size_n], dtype=torch.float32)
        for tile_n in hl.tile(n, block_size=block_size_n):
            values = x[tile_m, tile_n]
            local_amax = torch.amax(values, dim=1, keepdim=True)
            mi_next = torch.maximum(mi, local_amax)
            # Rescale the old partial sums to the new max, then add this tile
            di = di * torch.exp(mi - mi_next) + torch.exp(values - mi_next)
            mi = mi_next
        # Reduce the per-column partial sums to one denominator per row
        di = di.sum(dim=-1, keepdim=True)
        # Pass 2: normalize
        for tile_n in hl.tile(n, block_size=block_size_n):
            values = x[tile_m, tile_n]
            out[tile_m, tile_n] = torch.exp(values - mi) / di
    return out

The part that was hard before was referencing block_size_n before the tile_n loop.
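For reference, the recurrence the first loop implements is the standard online-softmax update. Here is a minimal plain-NumPy sketch of the same math (the function name and `block_n` parameter are illustrative, not part of Helion's API; this mirrors the arithmetic, not Helion's tiling or execution model):

```python
import numpy as np

def softmax_two_pass_reference(x: np.ndarray, block_n: int = 128) -> np.ndarray:
    """Two-pass online softmax over column blocks (NumPy sketch)."""
    m, n = x.shape
    mi = np.full((m, 1), -np.inf)  # running row max
    di = np.zeros((m, 1))          # running denominator, rescaled as mi changes
    # Pass 1: stream over column blocks, updating max and denominator
    for j in range(0, n, block_n):
        values = x[:, j:j + block_n]
        local_amax = values.max(axis=1, keepdims=True)
        mi_next = np.maximum(mi, local_amax)
        # Rescale the old denominator to the new max, then add this block's terms
        di = di * np.exp(mi - mi_next) + np.exp(values - mi_next).sum(axis=1, keepdims=True)
        mi = mi_next
    # Pass 2: re-read each block and normalize
    out = np.empty_like(x, dtype=np.float64)
    for j in range(0, n, block_n):
        out[:, j:j + block_n] = np.exp(x[:, j:j + block_n] - mi) / di
    return out
```

The rescaling step `di * exp(mi - mi_next)` is what lets the denominator be accumulated without ever seeing the whole row at once, which is why only two passes over the data are needed.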

I also considered something like:

di = hl.zeros([tile_m, hl.auto()], dtype=torch.float32)

where hl.auto() would be a placeholder value inferred automatically during type propagation, though I think the explicit register_block_size version is clearer.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 12, 2025
@jansel jansel requested review from yf225 and oulgen May 12, 2025 04:18
@jansel jansel merged commit 296f09d into main May 12, 2025
5 of 7 checks passed
@jansel jansel deleted the register_block_size202505 branch May 12, 2025 15:41