
Commit 0650950

Prototyping an hl.atomic op
stack-info: PR: #63, branch: drisspg/stack/5
1 parent: fa03be7

7 files changed: +331 -4 lines changed


helion/_compiler/device_ir.py

Lines changed: 3 additions & 0 deletions
@@ -696,6 +696,9 @@ def visit_Attribute(self, node: ast.Attribute) -> object:
             raise exc.CantReadOnDevice(type_info) from None
         return getattr(self.visit(node.value), node.attr)
 
+    def visit_Expr(self, node: ast.Expr) -> object:
+        return self.visit(node.value)
+
     def visit_Constant(self, node: ast.Constant) -> object:
         return node.value
 
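Why the new visit_Expr handler is needed: an op called purely for its side effect, such as the hl.atomic_add call introduced below, appears in the AST as an expression statement, so the visitor must unwrap ast.Expr to reach the underlying call. A minimal illustration using the standard ast module (the kernel snippet inside the string is illustrative only):

    import ast

    # A call used as a bare statement parses to an ast.Expr wrapping an ast.Call;
    # visit_Expr simply forwards to node.value so the call itself gets visited.
    tree = ast.parse("hl.atomic_add(x, [i], y[i])")
    stmt = tree.body[0]
    assert isinstance(stmt, ast.Expr)
    assert isinstance(stmt.value, ast.Call)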

helion/_compiler/indexing_strategy.py

Lines changed: 1 addition & 2 deletions
@@ -277,10 +277,9 @@ def create(
                 mask_values.setdefault(f"({mask}){expand}")
                 output_idx += 1
             else:
-                raise exc.InvalidIndexingType(k)
+                raise exc.InvalidIndexingType(type(k))
         assert len(output_size) == output_idx - first_non_grid_index
         assert len(index_values) == fake_value.ndim
-
         index_expr = []
         for i, idx in enumerate(index_values):
             if fake_value.size(i) != 1:

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 from .loops import grid as grid
 from .loops import register_block_size as register_block_size
 from .loops import tile as tile
+from .memory_ops import atomic_add as atomic_add
 from .memory_ops import load as load
 from .memory_ops import store as store
 from .view_ops import subscript as subscript

helion/language/_decorators.py

Lines changed: 26 additions & 0 deletions
@@ -38,6 +38,32 @@ def __call__(self, fn: Callable[..., _T]) -> object: ...
 
 
 class APIFunc(Protocol):
+    """Protocol for Helion API functions that define operations within kernel code.
+
+    This protocol defines the interface for functions decorated with @api. These functions
+    represent operations that can be called in Helion kernel code and are compiled
+    into the final device code.
+
+    Attributes:
+        __qualname__: The qualified name of the function.
+        _helion_api: A literal True marker indicating this is a Helion API function.
+        _is_device_loop: Whether this API function can transition between host and device code.
+            When True, the function can contain both host and device code sections.
+        _is_device_only: Whether this API function is intended for device code only.
+            When True, the function can only be used within device code sections.
+        _tiles_as_sizes: Whether tile indices should be converted to sizes automatically.
+            Used primarily with tiling operations to transform indices to dimensions.
+        _cache_type: Whether to cache the type information for repeated calls.
+        _type_function: A callable that determines the return type of this function
+            during type propagation phase.
+        _codegen: A callable that generates the device code for this function.
+        _fake_fn: A callable that provides a "fake" implementation used during
+            tracing and compilation.
+        _prepare_args: A callable that preprocesses the arguments before they're
+            passed to the actual function implementation.
+        _signature: The function signature for binding and validating arguments.
+    """
+
     __qualname__: str
     _helion_api: Literal[True]
     # a device loop can transition between host and device code
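For orientation, the hooks named in this docstring are filled in by the _decorators entry points that memory_ops.py uses below for the new atomic_add op. A condensed, hypothetical sketch of that wiring (the op name and bodies here are placeholders, not part of this commit):

    from helion.language import _decorators

    @_decorators.api()
    def my_op(tensor, index, value):
        raise NotImplementedError  # placeholder body; the real ops raise exc.NotInsideKernel on the host

    @_decorators.prepare_args(my_op)
    def _(tensor, index, value):
        return (tensor, index, value)  # normalize/validate arguments -> stored as _prepare_args

    @_decorators.register_fake(my_op)
    def _(tensor, index, value):
        return None  # stub used during tracing -> stored as _fake_fn

    @_decorators.codegen(my_op)
    def _(state):
        ...  # emit the device-side AST -> stored as _codegen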

helion/language/memory_ops.py

Lines changed: 92 additions & 1 deletion
@@ -14,7 +14,7 @@
 
     from .._compiler.inductor_lowering import CodegenState
 
-__all__ = ["load", "store"]
+__all__ = ["atomic_add", "load", "store"]
 
 
 @has_side_effect
@@ -53,6 +53,14 @@ def _(state: CodegenState) -> ast.AST:
 
 @_decorators.api(tiles_as_sizes=True)
 def load(tensor: torch.Tensor, index: list[object]) -> torch.Tensor:
+    """Load a value from a tensor using a list of indices.
+
+    Args:
+        tensor: The tensor to load from
+        index: The indices to use to index into the tensor
+    Returns:
+        torch.Tensor: The loaded value
+    """
     raise exc.NotInsideKernel
 
 
@@ -70,3 +78,86 @@ def _(state: CodegenState) -> ast.AST:
     return state.device_function.indexing_strategy.codegen_load(
         state, tensor, [*subscript]
     )
+
+
+@has_side_effect
+@_decorators.api()
+def atomic_add(
+    target: torch.Tensor,
+    index: list[object],
+    value: torch.Tensor | float,
+    sem: str = "relaxed",
+) -> None:
+    """
+    Atomically add a value to a target tensor.
+
+    Args:
+        target: The tensor to add to
+        index: Indices into target for way to accumulate values
+        value: The value to add
+        sem: The memory ordering semantics (default: 'relaxed')
+
+    Returns:
+        None
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.prepare_args(atomic_add)
+def _(
+    target: torch.Tensor,
+    index: list[object],
+    value: torch.Tensor | float,
+    sem: str = "relaxed",
+) -> tuple[torch.Tensor, object, torch.Tensor | float | int, str]:
+    from helion._compiler.tile_index_proxy import TileIndexProxy
+
+    valid_sems = {"relaxed", "acquire", "release", "acq_rel"}
+    if sem not in valid_sems:
+        raise ValueError(
+            f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
+        )
+
+    index = TileIndexProxy.prepare_index(index)
+    index = TileIndexProxy.tiles_to_sizes(index)
+
+    return (target, index, value, sem)
+
+
+@_decorators.register_fake(atomic_add)
+def _(
+    target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
+) -> None:
+    return None
+
+
+@_decorators.codegen(atomic_add)
+def _(state: CodegenState) -> ast.AST:
+    import ast
+
+    from .._compiler.ast_extension import expr_from_string
+
+    target = state.proxy_arg(0)
+    index = state.proxy_arg(1)
+    value = state.proxy_arg(2)
+    sem = expr_from_string(f"'{state.proxy_arg(3)}'")
+
+    assert isinstance(target, torch.Tensor)
+    assert isinstance(index, (list))
+
+    indices = SubscriptIndexing.create(state, target, index)
+    name = state.device_function.tensor_arg(target).name
+
+    value_expr = (
+        state.ast_args[2]
+        if isinstance(value, torch.Tensor)
+        else ast.Constant(value=value)
+    )
+    assert isinstance(value_expr, ast.AST)
+    return expr_from_string(
+        f"tl.atomic_add({name} + offset, value, mask=mask, sem=sem)",
+        value=value_expr,
+        offset=indices.index_expr,
+        mask=indices.mask_expr,
+        sem=sem,
+    )
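Taken together, these hooks expose hl.atomic_add to kernel code. A minimal usage sketch mirroring the atomic_add_kernel test added in this commit (the kernel name here is illustrative; running it requires a device Helion supports):

    import torch

    import helion
    import helion.language as hl

    @helion.kernel()
    def add_into(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # accumulate y into x in place; atomics keep it correct even if indices overlap
        for i in hl.tile(x.size(0)):
            hl.atomic_add(x, [i], y[i])
        return x

    # x = torch.zeros(10, device="cuda"); y = torch.ones(10, device="cuda")
    # add_into(x, y)  # x becomes all ones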

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -87,4 +87,3 @@ exclude = [
 
 [tool.hatch.metadata]
 allow-direct-references = true
-

test/test_atomic_add.py

Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
+from __future__ import annotations
+
+import unittest
+
+from expecttest import TestCase
+import torch
+
+import helion
+from helion._testing import DEVICE
+from helion._testing import code_and_output
+import helion.language as hl
+
+
+@helion.kernel()
+def atomic_add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Test basic atomic_add functionality."""
+    for i in hl.tile(x.size(0)):
+        hl.atomic_add(x, [i], y[i])
+    return x
+
+
+@helion.kernel(static_shapes=True)
+def atomic_add_overlap_kernel(
+    x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor
+) -> torch.Tensor:
+    """Test atomic_add with overlapping indices."""
+    for i in hl.tile([y.size(0)]):
+        idx = indices[i]
+        hl.atomic_add(x, [idx], y[i])
+    return x
+
+
+@helion.kernel()
+def atomic_add_2d_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Test atomic_add with 2D indexing."""
+    for i, j in hl.tile([y.size(0), y.size(1)]):
+        hl.atomic_add(x, [i, j], y[i, j])
+    return x
+
+
+@helion.kernel()
+def atomic_add_float_kernel(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+    """Test atomic_add with a float constant value and reading from lookup"""
+    for i in hl.tile(indices.size(0)):
+        idx = indices[i]
+        hl.atomic_add(x, [idx], 2.0)
+    return x
+
+
+class TestAtomicOperations(TestCase):
+    maxDiff = 16384
+
+    def test_basic_atomic_add(self):
+        x = torch.zeros(10, device=DEVICE)
+        y = torch.ones(10, device=DEVICE)
+        args = (x, y)
+
+        code, result = code_and_output(
+            atomic_add_kernel,
+            args,
+            block_sizes=[32],
+        )
+
+        expected = torch.ones(10, device=DEVICE)
+        torch.testing.assert_close(result, expected)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _atomic_add_kernel_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    load = tl.load(y + indices_0 * y_stride_0, mask_0, other=0)
+    tl.atomic_add(x + indices_0 * x_stride_0, load, mask=mask_0, sem='relaxed')
+
+def atomic_add_kernel(x: torch.Tensor, y: torch.Tensor):
+    \"\"\"Test basic atomic_add functionality.\"\"\"
+    _BLOCK_SIZE_0 = 32
+    _atomic_add_kernel_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return x
+
+def _atomic_add_kernel_make_precompiler(x: torch.Tensor, y: torch.Tensor):
+    \"\"\"Test basic atomic_add functionality.\"\"\"
+    _BLOCK_SIZE_0 = 32
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_atomic_add_kernel_kernel)(x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
+        )
+
+    def test_overlapping_atomic_add(self):
+        # Test with overlapping indices
+        x = torch.zeros(5, device=DEVICE)
+        y = torch.ones(10, device=DEVICE)
+        indices = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], device=DEVICE)
+        args = (x, y, indices)
+
+        code, result = code_and_output(
+            atomic_add_overlap_kernel,
+            args,
+            block_sizes=[32],
+        )
+
+        expected = torch.ones(5, device=DEVICE) * 2
+        torch.testing.assert_close(result, expected)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _atomic_add_overlap_kernel_kernel(indices, y, x, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+    mask_0 = indices_0 < 10
+    idx = tl.load(indices + indices_0 * 1, mask_0, other=0)
+    load_1 = tl.load(y + indices_0 * 1, mask_0, other=0)
+    tl.atomic_add(x + idx * 1, load_1, mask=mask_0, sem='relaxed')
+
+def atomic_add_overlap_kernel(x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor):
+    \"\"\"Test atomic_add with overlapping indices.\"\"\"
+    _BLOCK_SIZE_0 = 32
+    _atomic_add_overlap_kernel_kernel[triton.cdiv(10, _BLOCK_SIZE_0),](indices, y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return x
+
+def _atomic_add_overlap_kernel_make_precompiler(x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor):
+    \"\"\"Test atomic_add with overlapping indices.\"\"\"
+    _BLOCK_SIZE_0 = 32
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_atomic_add_overlap_kernel_kernel)(indices, y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
+        )
+
+    def test_2d_atomic_add(self):
+        """Test atomic_add with 2D tensor indexing."""
+        x = torch.zeros(3, 4, device=DEVICE)
+        y = torch.ones(3, 4, device=DEVICE)
+        args = (x, y)
+
+        code, result = code_and_output(
+            atomic_add_2d_kernel,
+            args,
+            block_sizes=[8, 8],
+        )
+
+        expected = torch.ones(3, 4, device=DEVICE)
+        torch.testing.assert_close(result, expected)
+        self.assertIn("atomic_add", code)
+
+    def test_atomic_add_code_generation(self):
+        """Test that the generated code contains atomic_add."""
+        x = torch.zeros(10, device=DEVICE)
+        y = torch.ones(10, device=DEVICE)
+        args = (x, y)
+
+        code, _ = code_and_output(atomic_add_kernel, args)
+        self.assertIn("atomic_add", code)
+
+    def test_atomic_add_float(self):
+        """Test that atomic_add works with float constants."""
+        x = torch.zeros(5, device=DEVICE, dtype=torch.float32)
+
+        indices = torch.tensor([0, 1, 2, 2, 3, 3, 3, 4], device=DEVICE)
+        expected = torch.tensor(
+            [2.0, 2.0, 4.0, 6.0, 2.0], device=DEVICE, dtype=torch.float32
+        )
+
+        args = (x, indices)
+        code, result = code_and_output(
+            atomic_add_float_kernel,
+            args,
+            block_sizes=[32],
+        )
+
+        torch.testing.assert_close(result, expected)
+
+    def test_atomic_add_invalid_sem(self):
+        """Test that atomic_add raises with an invalid sem value."""
+        x = torch.zeros(10, device=DEVICE)
+        y = torch.ones(10, device=DEVICE)
+
+        @helion.kernel()
+        def bad_atomic_add_kernel(x: torch.Tensor, y: torch.Tensor):
+            for i in hl.tile(x.size(0)):
+                hl.atomic_add(x, [i], y[i], sem="ERROR")
+            return x
+
+        with self.assertRaises(helion.exc.InternalError) as ctx:
+            code_and_output(
+                bad_atomic_add_kernel,
+                (x, y),
+                block_sizes=[32],
+            )
+        self.assertIn("Invalid memory semantic 'ERROR'", str(ctx.exception))
+
+
+if __name__ == "__main__":
+    unittest.main()
