
Commit 0fb0b67

Prototyping an hl.atomic opp
stack-info: PR: #63, branch: drisspg/stack/5
1 parent e76907d commit 0fb0b67

File tree

7 files changed: +254 -3 lines

helion/_compiler/device_ir.py

Lines changed: 19 additions & 0 deletions
@@ -693,6 +693,25 @@ def visit_Attribute(self, node: ast.Attribute) -> object:
             raise exc.CantReadOnDevice(type_info) from None
         return getattr(self.visit(node.value), node.attr)

+    def visit_Expr(self, node):
+        approved_ops = ["atomic_add"]
+
+        # Check if there is an inner call to an approved op
+        if isinstance(node.value, ast.Call):
+            func = node.value.func
+            op_name = None
+
+            if isinstance(func, ast.Name):
+                op_name = func.id
+
+            elif isinstance(func, ast.Attribute):
+                op_name = func.attr
+
+            if op_name in approved_ops:
+                return self.visit(node.value)
+
+        raise exc.StatementNotSupported(type(node).__name__)
+
     def visit_Constant(self, node: ast.Constant) -> object:
         return node.value
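The new visit_Expr hook only lets a bare expression statement through when it wraps a call to an op on the approved list; anything else still raises StatementNotSupported. A small illustration of the two call shapes it recognizes (this sketch uses only Python's standard ast module, not Helion itself):

# Illustration only (not part of the diff): how the two call shapes handled by
# visit_Expr look in Python's ast module. Both resolve to the name "atomic_add".
import ast

for src in ("hl.atomic_add(x, [i], y[i])", "atomic_add(x, [i], y[i])"):
    stmt = ast.parse(src).body[0]          # an ast.Expr statement node
    call = stmt.value                      # the inner ast.Call
    func = call.func
    if isinstance(func, ast.Name):         # bare name: atomic_add(...)
        op_name = func.id
    elif isinstance(func, ast.Attribute):  # attribute access: hl.atomic_add(...)
        op_name = func.attr
    print(src, "->", op_name)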

helion/_compiler/indexing_strategy.py

Lines changed: 0 additions & 1 deletion
@@ -254,7 +254,6 @@ def create(
                 raise exc.InvalidIndexingType(k)
         assert len(output_size) == output_idx
         assert len(index_values) == fake_value.ndim
-
         index_expr = []
         for i, idx in enumerate(index_values):
             if fake_value.size(i) != 1:

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 from .creation_ops import zeros as zeros
 from .loops import register_block_size as register_block_size
 from .loops import tile as tile
+from .memory_ops import atomic_add as atomic_add
 from .memory_ops import load as load
 from .memory_ops import store as store
 from .view_ops import subscript as subscript

helion/language/_decorators.py

Lines changed: 26 additions & 0 deletions
@@ -38,6 +38,32 @@ def __call__(self, fn: Callable[..., _T]) -> object: ...
 
 
 class APIFunc(Protocol):
+    """Protocol for Helion API functions that define operations within kernel code.
+
+    This protocol defines the interface for functions decorated with @api. These functions
+    represent operations that can be called in Helion kernel code and are compiled
+    into the final device code.
+
+    Attributes:
+        __qualname__: The qualified name of the function.
+        _helion_api: A literal True marker indicating this is a Helion API function.
+        _is_device_loop: Whether this API function can transition between host and device code.
+            When True, the function can contain both host and device code sections.
+        _is_device_only: Whether this API function is intended for device code only.
+            When True, the function can only be used within device code sections.
+        _tiles_as_sizes: Whether tile indices should be converted to sizes automatically.
+            Used primarily with tiling operations to transform indices to dimensions.
+        _cache_type: Whether to cache the type information for repeated calls.
+        _type_function: A callable that determines the return type of this function
+            during type propagation phase.
+        _codegen: A callable that generates the device code for this function.
+        _fake_fn: A callable that provides a "fake" implementation used during
+            tracing and compilation.
+        _prepare_args: A callable that preprocesses the arguments before they're
+            passed to the actual function implementation.
+        _signature: The function signature for binding and validating arguments.
+    """
+
     __qualname__: str
     _helion_api: Literal[True]
     # a device loop can transition between host and device code
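The attributes documented above are attached by the @_decorators.api machinery; atomic_add in memory_ops.py below is registered with tiles_as_sizes=True, so, assuming the decorated function object carries these markers as the protocol describes, they can be inspected directly. A hypothetical inspection sketch:

# Hypothetical sketch, not part of the diff: reading the protocol markers off the
# decorated hl.atomic_add object. The attribute names come from the docstring above;
# the expected values are assumptions based on the decorators applied in memory_ops.py.
import helion.language as hl

fn = hl.atomic_add
print(getattr(fn, "_helion_api", None))      # expected: True for @api-decorated functions
print(getattr(fn, "_tiles_as_sizes", None))  # atomic_add is declared with tiles_as_sizes=True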

helion/language/memory_ops.py

Lines changed: 66 additions & 1 deletion
@@ -14,7 +14,7 @@
 
     from .._compiler.inductor_lowering import CodegenState
 
-__all__ = ["load", "store"]
+__all__ = ["atomic_add", "load", "store"]
 
 
 @has_side_effect
@@ -53,6 +53,14 @@ def _(state: CodegenState) -> ast.AST:
 
 @_decorators.api(tiles_as_sizes=True)
 def load(tensor: torch.Tensor, index: list[object]) -> torch.Tensor:
+    """Load a value from a tensor using a list of indices.
+
+    Args:
+        tensor: The tensor to load from
+        index: The indices to use to index into the tensor
+    Returns:
+        torch.Tensor: The loaded value
+    """
     raise exc.NotInsideKernel
 
 
@@ -70,3 +78,60 @@ def _(state: CodegenState) -> ast.AST:
     return state.device_function.indexing_strategy.codegen_load(
         state, tensor, [*subscript]
     )
+
+
+@has_side_effect
+@_decorators.api(tiles_as_sizes=True)
+def atomic_add(target: torch.Tensor, index: list[object], value: torch.Tensor) -> None:
+    """
+    Atomically add a value to a target tensor.
+
+    Args:
+        target: The tensor to add to
+        index: Indices into target for way to accumulate values
+        value: The value to add
+
+    Returns:
+        None
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.prepare_args(atomic_add)
+def _(
+    target: torch.Tensor, index: list[object], value: torch.Tensor
+) -> tuple[torch.Tensor, object, torch.Tensor]:
+    from helion._compiler.tile_index_proxy import TileIndexProxy
+
+    assert value.dtype == target.dtype, (
+        f"Expected value dtype {target.dtype}, got {value.dtype}"
+    )
+    index = TileIndexProxy.prepare_index(index)
+    index = TileIndexProxy.tiles_to_sizes(index)
+    return (target, index, value)
+
+
+@_decorators.register_fake(atomic_add)
+def _(target: torch.Tensor, index: list[object], value: torch.Tensor) -> None:
+    return None
+
+
+@_decorators.codegen(atomic_add)
+def _(state: CodegenState) -> ast.AST:
+    from .._compiler.ast_extension import expr_from_string
+
+    target = state.proxy_arg(0)
+    index = state.proxy_arg(1)
+    value = state.proxy_arg(2)
+    assert isinstance(target, torch.Tensor)
+    assert isinstance(value, torch.Tensor)
+
+    indices = SubscriptIndexing.create(state, target, index)
+    name = state.device_function.tensor_arg(target).name
+    return expr_from_string(
+        f"tl.atomic_add({name} + offset, value, mask=mask, sem=sem)",
+        value=state.ast_args[2],
+        offset=indices.index_expr,
+        mask=indices.mask_expr,
+        sem=expr_from_string("'relaxed'"),
+    )
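On the host side, the user-facing entry point is the hl.atomic_add call inside a tiled loop. The kernel below mirrors the basic test added in test/test_atomic_add.py further down in this commit, shown here as a usage sketch:

# Usage sketch mirroring test/test_atomic_add.py below: atomically accumulate y into x.
import torch
import helion
import helion.language as hl

@helion.kernel()
def atomic_add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    for i in hl.tile([x.size(0)]):
        hl.atomic_add(x, [i], y[i])  # lowered to tl.atomic_add(..., sem='relaxed')
    return x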

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -87,4 +87,3 @@ exclude = [
 
 [tool.hatch.metadata]
 allow-direct-references = true
-

test/test_atomic_add.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@ (new file)
from __future__ import annotations

import unittest

from expecttest import TestCase
import torch

import helion
from helion._testing import DEVICE
from helion._testing import code_and_output
import helion.language as hl


@helion.kernel()
def atomic_add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Test basic atomic_add functionality."""
    for i in hl.tile([x.size(0)]):
        hl.atomic_add(x, [i], y[i])
    return x


@helion.kernel()
def atomic_add_overlap_kernel(
    x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor
) -> torch.Tensor:
    """Test atomic_add with overlapping indices."""
    for i in hl.tile([y.size(0)]):
        idx = indices[i]
        hl.atomic_add(x, [idx], y[i])
    return x


class TestAtomicOperations(TestCase):
    maxDiff = 16384

    def test_basic_atomic_add(self):
        # Basic test with sequential indices
        x = torch.zeros(10, device=DEVICE)
        y = torch.ones(10, device=DEVICE)
        args = (x, y)

        code, result = code_and_output(
            atomic_add_kernel,
            args,
            block_sizes=[32],
        )

        expected = torch.ones(10, device=DEVICE)
        torch.testing.assert_close(result, expected)
        self.assertExpectedInline(
            code,
            """\
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _atomic_add_kernel_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
    mask_0 = indices_0 < x_size_0
    load = tl.load(y + indices_0 * y_stride_0, mask_0, other=0)
    tl.atomic_add(x + indices_0 * x_stride_0, load, mask=mask_0, sem='relaxed')

def atomic_add_kernel(x: torch.Tensor, y: torch.Tensor):
    \"\"\"Test basic atomic_add functionality.\"\"\"
    _BLOCK_SIZE_0 = 32
    _atomic_add_kernel_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
    return x

def _atomic_add_kernel_make_precompiler(x: torch.Tensor, y: torch.Tensor):
    \"\"\"Test basic atomic_add functionality.\"\"\"
    _BLOCK_SIZE_0 = 32
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_atomic_add_kernel_kernel)(x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
        )

    def test_overlapping_atomic_add(self):
        # Test with overlapping indices
        x = torch.zeros(5, device=DEVICE)
        y = torch.ones(10, device=DEVICE)
        indices = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], device=DEVICE)
        args = (x, y, indices)

        code, result = code_and_output(
            atomic_add_overlap_kernel,
            args,
            block_sizes=[32],
        )

        expected = torch.ones(5, device=DEVICE) * 2
        torch.testing.assert_close(result, expected)
        self.assertExpectedInline(
            code,
            """\
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _atomic_add_overlap_kernel_kernel(y, indices, x, y_size_0, indices_stride_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
    mask_0 = indices_0 < y_size_0
    idx = tl.load(indices + indices_0 * indices_stride_0, mask_0, other=0)
    load_1 = tl.load(y + indices_0 * y_stride_0, mask_0, other=0)
    tl.atomic_add(x + idx * x_stride_0, load_1, mask=mask_0, sem='relaxed')

def atomic_add_overlap_kernel(x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor):
    \"\"\"Test atomic_add with overlapping indices.\"\"\"
    _BLOCK_SIZE_0 = 32
    _atomic_add_overlap_kernel_kernel[triton.cdiv(y.size(0), _BLOCK_SIZE_0),](y, indices, x, y.size(0), indices.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
    return x

def _atomic_add_overlap_kernel_make_precompiler(x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor):
    \"\"\"Test atomic_add with overlapping indices.\"\"\"
    _BLOCK_SIZE_0 = 32
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_atomic_add_overlap_kernel_kernel)(y, indices, x, y.size(0), indices.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
        )

    def test_atomic_add_code_generation(self):
        """Test that the generated code contains atomic_add."""
        x = torch.zeros(10, device=DEVICE)
        y = torch.ones(10, device=DEVICE)
        args = (x, y)

        code, _ = code_and_output(atomic_add_kernel, args)
        # Verify atomic_add appears in the generated code
        self.assertIn("atomic_add", code)
        # Verify the new signature format (using the target tensor and indices list)
        self.assertIn("tl.atomic_add(x + offset_0", code)


if __name__ == "__main__":
    unittest.main()
