Skip to content

Commit 804cd24

Browse files
committed
Prototyping an hl.atomic_add op
stack-info: PR: #63, branch: drisspg/stack/5
1 parent e76907d commit 804cd24

File tree

7 files changed

+160
-3
lines changed

7 files changed

+160
-3
lines changed

helion/_compiler/device_ir.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,15 @@ def visit_Attribute(self, node: ast.Attribute) -> object:
693693
raise exc.CantReadOnDevice(type_info) from None
694694
return getattr(self.visit(node.value), node.attr)
695695

696+
def visit_Expr(self, node: ast.Expr) -> object:
    """Handle an expression statement by visiting the expression it wraps.

    Args:
        node: The ``ast.Expr`` statement node.

    Returns:
        Whatever ``self.visit`` returns for ``node.value``.
    """
    # NOTE(review): the previous body had an `_is_helion_op` guard and a
    # `StatementNotSupported` fallback placed *after* this return, so neither
    # could ever run. That dead code is removed; runtime behaviour is unchanged.
    # If non-Helion expression statements should be rejected, reinstate the
    # guard ahead of this return (and make sure `_is_helion_op` exists).
    return self.visit(node.value)
704+
696705
def visit_Constant(self, node: ast.Constant) -> object:
697706
return node.value
698707

helion/_compiler/indexing_strategy.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,6 @@ def create(
254254
raise exc.InvalidIndexingType(k)
255255
assert len(output_size) == output_idx
256256
assert len(index_values) == fake_value.ndim
257-
258257
index_expr = []
259258
for i, idx in enumerate(index_values):
260259
if fake_value.size(i) != 1:

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .creation_ops import zeros as zeros
66
from .loops import register_block_size as register_block_size
77
from .loops import tile as tile
8+
from .memory_ops import atomic_add as atomic_add
89
from .memory_ops import load as load
910
from .memory_ops import store as store
1011
from .view_ops import subscript as subscript

helion/language/_decorators.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,32 @@ def __call__(self, fn: Callable[..., _T]) -> object: ...
3838

3939

4040
class APIFunc(Protocol):
41+
"""Protocol for Helion API functions that define operations within kernel code.
42+
43+
This protocol defines the interface for functions decorated with @api. These functions
44+
represent operations that can be called in Helion kernel code and are compiled
45+
into the final device code.
46+
47+
Attributes:
48+
__qualname__: The qualified name of the function.
49+
_helion_api: A literal True marker indicating this is a Helion API function.
50+
_is_device_loop: Whether this API function can transition between host and device code.
51+
When True, the function can contain both host and device code sections.
52+
_is_device_only: Whether this API function is intended for device code only.
53+
When True, the function can only be used within device code sections.
54+
_tiles_as_sizes: Whether tile indices should be converted to sizes automatically.
55+
Used primarily with tiling operations to transform indices to dimensions.
56+
_cache_type: Whether to cache the type information for repeated calls.
57+
_type_function: A callable that determines the return type of this function
58+
during type propagation phase.
59+
_codegen: A callable that generates the device code for this function.
60+
_fake_fn: A callable that provides a "fake" implementation used during
61+
tracing and compilation.
62+
_prepare_args: A callable that preprocesses the arguments before they're
63+
passed to the actual function implementation.
64+
_signature: The function signature for binding and validating arguments.
65+
"""
66+
4167
__qualname__: str
4268
_helion_api: Literal[True]
4369
# a device loop can transition between host and device code

helion/language/memory_ops.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from .._compiler.inductor_lowering import CodegenState
1616

17-
__all__ = ["load", "store"]
17+
__all__ = ["atomic_add", "load", "store"]
1818

1919

2020
@has_side_effect
@@ -53,6 +53,14 @@ def _(state: CodegenState) -> ast.AST:
5353

5454
@_decorators.api(tiles_as_sizes=True)
def load(tensor: torch.Tensor, index: list[object]) -> torch.Tensor:
    """Read a value out of ``tensor`` at the position given by ``index``.

    Args:
        tensor: Tensor to read from.
        index: List of indices used to index into ``tensor``.

    Returns:
        torch.Tensor: The loaded value.

    Raises:
        exc.NotInsideKernel: Always, whenever this stub body itself runs.
    """
    raise exc.NotInsideKernel
5765

5866

@@ -70,3 +78,60 @@ def _(state: CodegenState) -> ast.AST:
7078
return state.device_function.indexing_strategy.codegen_load(
7179
state, tensor, [*subscript]
7280
)
81+
82+
83+
@has_side_effect
@_decorators.api(tiles_as_sizes=True)
def atomic_add(target: torch.Tensor, index: list[object], value: torch.Tensor) -> None:
    """Atomically add ``value`` into ``target`` at the given ``index``.

    Args:
        target: The tensor to add to.
        index: List of indices selecting where in ``target`` to add.
        value: The value to add.

    Returns:
        None

    Raises:
        exc.NotInsideKernel: Always, whenever this stub body itself runs.
    """
    raise exc.NotInsideKernel
97+
98+
99+
@_decorators.prepare_args(atomic_add)
def _(
    target: torch.Tensor, index: object, value: torch.Tensor
) -> tuple[torch.Tensor, object, torch.Tensor]:
    """Validate and normalise the arguments of an ``atomic_add`` call.

    Asserts that ``value`` and ``target`` share a dtype, and passes a list or
    tuple ``index`` through ``TileIndexProxy.tiles_to_sizes``. Any other kind
    of ``index`` is returned untouched.

    Returns:
        The ``(target, index, value)`` triple, with ``index`` possibly converted.
    """
    from helion._compiler.tile_index_proxy import TileIndexProxy

    assert value.dtype == target.dtype, (
        f"Expected value dtype {target.dtype}, got {value.dtype}"
    )
    # Only sequence-style indices go through the tile -> size conversion.
    prepared_index = (
        TileIndexProxy.tiles_to_sizes(index)
        if isinstance(index, (list, tuple))
        else index
    )
    return (target, prepared_index, value)
112+
113+
114+
@_decorators.register_fake(atomic_add)
def _(target: torch.Tensor, index: list[object], value: torch.Tensor) -> None:
    """Fake implementation of ``atomic_add``: does nothing and returns ``None``."""
    return
117+
118+
119+
@_decorators.codegen(atomic_add)
def _(state: CodegenState) -> ast.AST:
    """Build the ``tl.atomic_add(...)`` expression for an ``atomic_add`` call.

    The offset and mask expressions come from ``SubscriptIndexing.create`` for
    the target tensor and index; the value expression is the third AST
    argument of the call. The ``sem`` argument is always ``'relaxed'``.

    Args:
        state: Codegen state for the call being lowered.

    Returns:
        The AST produced by ``expr_from_string`` for the atomic add.
    """
    from .._compiler.ast_extension import expr_from_string

    target, index, value = (state.proxy_arg(position) for position in range(3))
    assert isinstance(target, torch.Tensor)
    assert isinstance(value, torch.Tensor)

    indexing = SubscriptIndexing.create(state, target, index)
    tensor_name = state.device_function.tensor_arg(target).name
    # `offset`, `value`, `mask` and `sem` are placeholders that
    # expr_from_string fills from the keyword arguments below.
    template = f"tl.atomic_add({tensor_name} + offset, value, mask=mask, sem=sem)"
    return expr_from_string(
        template,
        value=state.ast_args[2],
        offset=indexing.index_expr,
        mask=indexing.mask_expr,
        sem=expr_from_string("'relaxed'"),
    )

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,3 @@ exclude = [
8787

8888
[tool.hatch.metadata]
8989
allow-direct-references = true
90-

test/test_atomic_add.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from __future__ import annotations
2+
3+
import torch
4+
5+
import helion
6+
from helion._testing import code_and_output
7+
import helion.language as hl
8+
9+
10+
@helion.kernel()
def atomic_add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Kernel that atomically adds each tile of ``y`` into the same tile of ``x``.

    Returns:
        ``x``, the tensor that was passed as the ``atomic_add`` target.
    """
    for i in hl.tile([x.size(0)]):
        # atomic_add takes (target, index, value): pass the tensor and the
        # index list separately rather than a pre-indexed `x[i]`.
        hl.atomic_add(x, [i], y[i])
    return x
16+
17+
18+
@helion.kernel()
def atomic_add_overlap_kernel(
    x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor
) -> torch.Tensor:
    """Kernel that atomically adds ``y`` into ``x`` at positions read from ``indices``.

    ``indices`` may name the same position of ``x`` more than once, so several
    elements of ``y`` can be added into one element of ``x``.

    Returns:
        ``x``, the tensor that was passed as the ``atomic_add`` target.
    """
    for i in hl.tile([y.size(0)]):
        idx = indices[i]
        # atomic_add takes (target, index, value): pass the tensor and the
        # index list separately rather than a pre-indexed `x[idx]`.
        hl.atomic_add(x, [idx], y[i])
    return x
27+
28+
29+
def test_atomic_add():
    """Check atomic_add results for distinct and for repeated target indices."""
    # Case 1: every element is updated exactly once (zeros + ones == ones).
    base = torch.zeros(10, device="cuda")
    addend = torch.ones(10, device="cuda")
    result = atomic_add_kernel(base, addend)
    assert torch.allclose(result, addend), f"Expected {addend}, got {result}"

    # Case 2: each of the 5 slots is named twice in `indices`, so each slot
    # should end up at 2.
    accumulator = torch.zeros(5, device="cuda")
    updates = torch.ones(10, device="cuda")
    indices = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], device="cuda")
    result = atomic_add_overlap_kernel(accumulator, updates, indices)
    expected = torch.ones(5, device="cuda") * 2
    assert torch.allclose(result, expected), f"Expected {expected}, got {result}"
46+
47+
48+
def test_atomic_add_code():
    """Test that the atomic_add code is correctly generated."""
    # code_and_output needs the kernel's call arguments to compile and run it;
    # the original call passed only the kernel.
    x = torch.zeros(10, device="cuda")
    y = torch.ones(10, device="cuda")
    code, _ = code_and_output(atomic_add_kernel, (x, y))
    # Ensure "atomic_add" appears in the generated code
    assert "atomic_add" in code, f"Expected 'atomic_add' in generated code, got: {code}"
53+
54+
55+
if __name__ == "__main__":
56+
test_atomic_add()
57+
test_atomic_add_code()
58+
print("All tests passed!")

0 commit comments

Comments
 (0)