Prototyping an hl.atomic op #63

Merged: 1 commit, May 23, 2025

3 changes: 3 additions & 0 deletions helion/_compiler/device_ir.py
@@ -696,6 +696,9 @@ def visit_Attribute(self, node: ast.Attribute) -> object:
raise exc.CantReadOnDevice(type_info) from None
return getattr(self.visit(node.value), node.attr)

def visit_Expr(self, node: ast.Expr) -> object:
return self.visit(node.value)

def visit_Constant(self, node: ast.Constant) -> object:
return node.value

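The new visit_Expr handler is presumably what lets a bare statement such as hl.atomic_add(x, [i], y[i]) be traced: atomic_add returns None, so in kernel source it appears as an expression statement, which Python parses as an ast.Expr node wrapping an ast.Call. A quick illustration using only the standard ast module:

import ast

# A bare call statement parses as ast.Expr(value=ast.Call(...)), so the
# visitor needs visit_Expr to forward into the underlying call expression.
tree = ast.parse("hl.atomic_add(x, [i], y[i])")
stmt = tree.body[0]
assert isinstance(stmt, ast.Expr)
assert isinstance(stmt.value, ast.Call)
print(ast.dump(stmt, indent=2))
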
3 changes: 1 addition & 2 deletions helion/_compiler/indexing_strategy.py
@@ -277,10 +277,9 @@ def create(
mask_values.setdefault(f"({mask}){expand}")
output_idx += 1
else:
raise exc.InvalidIndexingType(k)
raise exc.InvalidIndexingType(type(k))
assert len(output_size) == output_idx - first_non_grid_index
assert len(index_values) == fake_value.ndim

index_expr = []
for i, idx in enumerate(index_values):
if fake_value.size(i) != 1:
1 change: 1 addition & 0 deletions helion/language/__init__.py
@@ -6,6 +6,7 @@
from .loops import grid as grid
from .loops import register_block_size as register_block_size
from .loops import tile as tile
from .memory_ops import atomic_add as atomic_add
from .memory_ops import load as load
from .memory_ops import store as store
from .view_ops import subscript as subscript
26 changes: 26 additions & 0 deletions helion/language/_decorators.py
@@ -38,6 +38,32 @@ def __call__(self, fn: Callable[..., _T]) -> object: ...


class APIFunc(Protocol):
"""Protocol for Helion API functions that define operations within kernel code.

This protocol defines the interface for functions decorated with @api. These functions
represent operations that can be called in Helion kernel code and are compiled
into the final device code.

Attributes:
__qualname__: The qualified name of the function.
_helion_api: A literal True marker indicating this is a Helion API function.
_is_device_loop: Whether this API function can transition between host and device code.
When True, the function can contain both host and device code sections.
_is_device_only: Whether this API function is intended for device code only.
When True, the function can only be used within device code sections.
_tiles_as_sizes: Whether tile indices should be converted to sizes automatically.
Used primarily with tiling operations to transform indices to dimensions.
_cache_type: Whether to cache the type information for repeated calls.
_type_function: A callable that determines the return type of this function
during type propagation phase.
_codegen: A callable that generates the device code for this function.
_fake_fn: A callable that provides a "fake" implementation used during
tracing and compilation.
_prepare_args: A callable that preprocesses the arguments before they're
passed to the actual function implementation.
_signature: The function signature for binding and validating arguments.
"""

__qualname__: str
_helion_api: Literal[True]
# a device loop can transition between host and device code
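To make the attribute list above concrete, here is a rough, hypothetical sketch of how an @api-style decorator could stamp these markers onto a function. This is not Helion's actual implementation, only an illustration of the protocol the later decorators in this PR (prepare_args, register_fake, codegen) build on:

import inspect
from typing import Callable


def api_sketch(*, is_device_only: bool = True, tiles_as_sizes: bool = False) -> Callable:
    # Hypothetical decorator: attaches the APIFunc attributes documented above
    # so the compiler can recognize the function and look up its handlers.
    def decorator(fn: Callable) -> Callable:
        fn._helion_api = True
        fn._is_device_loop = False
        fn._is_device_only = is_device_only
        fn._tiles_as_sizes = tiles_as_sizes
        fn._cache_type = False
        fn._type_function = None   # registered later by companion decorators,
        fn._codegen = None         # e.g. @_decorators.codegen(...)
        fn._fake_fn = None
        fn._prepare_args = lambda *args: args
        fn._signature = inspect.signature(fn)
        return fn
    return decorator
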
93 changes: 92 additions & 1 deletion helion/language/memory_ops.py
@@ -14,7 +14,7 @@

from .._compiler.inductor_lowering import CodegenState

__all__ = ["load", "store"]
__all__ = ["atomic_add", "load", "store"]


@has_side_effect
@@ -53,6 +53,14 @@ def _(state: CodegenState) -> ast.AST:

@_decorators.api(tiles_as_sizes=True)
def load(tensor: torch.Tensor, index: list[object]) -> torch.Tensor:
"""Load a value from a tensor using a list of indices.

Args:
tensor: The tensor to load from
index: The indices to use to index into the tensor
Returns:
torch.Tensor: The loaded value
"""
raise exc.NotInsideKernel


@@ -70,3 +78,86 @@ def _(state: CodegenState) -> ast.AST:
return state.device_function.indexing_strategy.codegen_load(
state, tensor, [*subscript]
)


@has_side_effect
@_decorators.api()
def atomic_add(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | float,
sem: str = "relaxed",
) -> None:
"""
Atomically add a value to a target tensor.

Args:
target: The tensor to add to
index: Indices into target selecting where to accumulate values
value: The value to add
sem: The memory ordering semantics (default: 'relaxed')

Returns:
None
"""
raise exc.NotInsideKernel


@_decorators.prepare_args(atomic_add)
def _(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | float,
sem: str = "relaxed",
) -> tuple[torch.Tensor, object, torch.Tensor | float | int, str]:
from helion._compiler.tile_index_proxy import TileIndexProxy

valid_sems = {"relaxed", "acquire", "release", "acq_rel"}
if sem not in valid_sems:
raise ValueError(
f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
)

index = TileIndexProxy.prepare_index(index)
index = TileIndexProxy.tiles_to_sizes(index)

return (target, index, value, sem)


@_decorators.register_fake(atomic_add)
def _(
target: torch.Tensor, index: list[object], value: torch.Tensor | float, sem: str = "relaxed"
) -> None:
return None


@_decorators.codegen(atomic_add)
def _(state: CodegenState) -> ast.AST:
import ast

from .._compiler.ast_extension import expr_from_string

target = state.proxy_arg(0)
index = state.proxy_arg(1)
value = state.proxy_arg(2)
sem = expr_from_string(f"'{state.proxy_arg(3)}'")

assert isinstance(target, torch.Tensor)
assert isinstance(index, list)

indices = SubscriptIndexing.create(state, target, index)
name = state.device_function.tensor_arg(target).name

value_expr = (
state.ast_args[2]
if isinstance(value, torch.Tensor)
else ast.Constant(value=value)
)
assert isinstance(value_expr, ast.AST)
return expr_from_string(
f"tl.atomic_add({name} + offset, value, mask=mask, sem=sem)",
value=value_expr,
offset=indices.index_expr,
mask=indices.mask_expr,
sem=sem,
)
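
For reference, a minimal usage sketch of the new op from kernel code, modeled on the tests added below; the kernel name, shapes, and device here are illustrative and not part of this PR:

import torch

import helion
import helion.language as hl


@helion.kernel()
def bucket_count_kernel(counts: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Overlapping indices are safe: each contribution is applied atomically
    # with the default 'relaxed' memory ordering.
    for i in hl.tile(indices.size(0)):
        idx = indices[i]
        hl.atomic_add(counts, [idx], 1.0)
    return counts


# counts = torch.zeros(5, device="cuda")
# indices = torch.tensor([0, 1, 2, 2, 3], device="cuda")
# bucket_count_kernel(counts, indices)
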
1 change: 0 additions & 1 deletion pyproject.toml
@@ -87,4 +87,3 @@ exclude = [

[tool.hatch.metadata]
allow-direct-references = true

208 changes: 208 additions & 0 deletions test/test_atomic_add.py
@@ -0,0 +1,208 @@
from __future__ import annotations

import unittest

from expecttest import TestCase
import torch

import helion
from helion._testing import DEVICE
from helion._testing import code_and_output
import helion.language as hl


@helion.kernel()
def atomic_add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Test basic atomic_add functionality."""
for i in hl.tile(x.size(0)):
hl.atomic_add(x, [i], y[i])
return x


@helion.kernel(static_shapes=True)
def atomic_add_overlap_kernel(
x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor
) -> torch.Tensor:
"""Test atomic_add with overlapping indices."""
for i in hl.tile([y.size(0)]):
idx = indices[i]
hl.atomic_add(x, [idx], y[i])
return x


@helion.kernel()
def atomic_add_2d_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Test atomic_add with 2D indexing."""
for i, j in hl.tile([y.size(0), y.size(1)]):
hl.atomic_add(x, [i, j], y[i, j])
return x


@helion.kernel()
def atomic_add_float_kernel(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
"""Test atomic_add with a float constant value and reading from lookup"""
for i in hl.tile(indices.size(0)):
idx = indices[i]
hl.atomic_add(x, [idx], 2.0)
return x


class TestAtomicOperations(TestCase):
maxDiff = 16384

def test_basic_atomic_add(self):
x = torch.zeros(10, device=DEVICE)
y = torch.ones(10, device=DEVICE)
args = (x, y)

code, result = code_and_output(
atomic_add_kernel,
args,
block_sizes=[32],
)

expected = torch.ones(10, device=DEVICE)
torch.testing.assert_close(result, expected)
self.assertExpectedInline(
code,
"""\
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _atomic_add_kernel_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
mask_0 = indices_0 < x_size_0
load = tl.load(y + indices_0 * y_stride_0, mask_0, other=0)
tl.atomic_add(x + indices_0 * x_stride_0, load, mask=mask_0, sem='relaxed')

def atomic_add_kernel(x: torch.Tensor, y: torch.Tensor):
\"\"\"Test basic atomic_add functionality.\"\"\"
_BLOCK_SIZE_0 = 32
_atomic_add_kernel_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return x

def _atomic_add_kernel_make_precompiler(x: torch.Tensor, y: torch.Tensor):
\"\"\"Test basic atomic_add functionality.\"\"\"
_BLOCK_SIZE_0 = 32
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_atomic_add_kernel_kernel)(x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
)

def test_overlapping_atomic_add(self):
# Test with overlapping indices
x = torch.zeros(5, device=DEVICE)
y = torch.ones(10, device=DEVICE)
indices = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], device=DEVICE)
args = (x, y, indices)

code, result = code_and_output(
atomic_add_overlap_kernel,
args,
block_sizes=[32],
)

expected = torch.ones(5, device=DEVICE) * 2
torch.testing.assert_close(result, expected)
self.assertExpectedInline(
code,
"""\
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _atomic_add_overlap_kernel_kernel(indices, y, x, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
mask_0 = indices_0 < 10
idx = tl.load(indices + indices_0 * 1, mask_0, other=0)
load_1 = tl.load(y + indices_0 * 1, mask_0, other=0)
tl.atomic_add(x + idx * 1, load_1, mask=mask_0, sem='relaxed')

def atomic_add_overlap_kernel(x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor):
\"\"\"Test atomic_add with overlapping indices.\"\"\"
_BLOCK_SIZE_0 = 32
_atomic_add_overlap_kernel_kernel[triton.cdiv(10, _BLOCK_SIZE_0),](indices, y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return x

def _atomic_add_overlap_kernel_make_precompiler(x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor):
\"\"\"Test atomic_add with overlapping indices.\"\"\"
_BLOCK_SIZE_0 = 32
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_atomic_add_overlap_kernel_kernel)(indices, y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
)

def test_2d_atomic_add(self):
"""Test atomic_add with 2D tensor indexing."""
x = torch.zeros(3, 4, device=DEVICE)
y = torch.ones(3, 4, device=DEVICE)
args = (x, y)

code, result = code_and_output(
atomic_add_2d_kernel,
args,
block_sizes=[8, 8],
)

expected = torch.ones(3, 4, device=DEVICE)
torch.testing.assert_close(result, expected)
self.assertIn("atomic_add", code)

def test_atomic_add_code_generation(self):
"""Test that the generated code contains atomic_add."""
x = torch.zeros(10, device=DEVICE)
y = torch.ones(10, device=DEVICE)
args = (x, y)

code, _ = code_and_output(atomic_add_kernel, args)
self.assertIn("atomic_add", code)

def test_atomic_add_float(self):
"""Test that atomic_add works with float constants."""
x = torch.zeros(5, device=DEVICE, dtype=torch.float32)

indices = torch.tensor([0, 1, 2, 2, 3, 3, 3, 4], device=DEVICE)
expected = torch.tensor(
[2.0, 2.0, 4.0, 6.0, 2.0], device=DEVICE, dtype=torch.float32
)

args = (x, indices)
code, result = code_and_output(
atomic_add_float_kernel,
args,
block_sizes=[32],
)

torch.testing.assert_close(result, expected)

def test_atomic_add_invalid_sem(self):
"""Test that atomic_add raises with an invalid sem value."""
x = torch.zeros(10, device=DEVICE)
y = torch.ones(10, device=DEVICE)

@helion.kernel()
def bad_atomic_add_kernel(x: torch.Tensor, y: torch.Tensor):
for i in hl.tile(x.size(0)):
hl.atomic_add(x, [i], y[i], sem="ERROR")
return x

with self.assertRaises(helion.exc.InternalError) as ctx:
code_and_output(
bad_atomic_add_kernel,
(x, y),
block_sizes=[32],
)
self.assertIn("Invalid memory semantic 'ERROR'", str(ctx.exception))


if __name__ == "__main__":
unittest.main()