Skip to content

Commit bf8c898

Browse files
committed
Prototyping an hl.atomic op
stack-info: PR: #63, branch: drisspg/stack/5
1 parent fa03be7 commit bf8c898

File tree

7 files changed

+395
-4
lines changed

7 files changed

+395
-4
lines changed

helion/_compiler/device_ir.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,9 @@ def visit_Attribute(self, node: ast.Attribute) -> object:
696696
raise exc.CantReadOnDevice(type_info) from None
697697
return getattr(self.visit(node.value), node.attr)
698698

699+
def visit_Expr(self, node):
700+
return self.visit(node.value)
701+
699702
def visit_Constant(self, node: ast.Constant) -> object:
700703
return node.value
701704

helion/_compiler/indexing_strategy.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,9 @@ def create(
277277
mask_values.setdefault(f"({mask}){expand}")
278278
output_idx += 1
279279
else:
280-
raise exc.InvalidIndexingType(k)
280+
raise exc.InvalidIndexingType(type(k))
281281
assert len(output_size) == output_idx - first_non_grid_index
282282
assert len(index_values) == fake_value.ndim
283-
284283
index_expr = []
285284
for i, idx in enumerate(index_values):
286285
if fake_value.size(i) != 1:

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .loops import grid as grid
77
from .loops import register_block_size as register_block_size
88
from .loops import tile as tile
9+
from .memory_ops import atomic_add as atomic_add
910
from .memory_ops import load as load
1011
from .memory_ops import store as store
1112
from .view_ops import subscript as subscript

helion/language/_decorators.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,32 @@ def __call__(self, fn: Callable[..., _T]) -> object: ...
3838

3939

4040
class APIFunc(Protocol):
41+
"""Protocol for Helion API functions that define operations within kernel code.
42+
43+
This protocol defines the interface for functions decorated with @api. These functions
44+
represent operations that can be called in Helion kernel code and are compiled
45+
into the final device code.
46+
47+
Attributes:
48+
__qualname__: The qualified name of the function.
49+
_helion_api: A literal True marker indicating this is a Helion API function.
50+
_is_device_loop: Whether this API function can transition between host and device code.
51+
When True, the function can contain both host and device code sections.
52+
_is_device_only: Whether this API function is intended for device code only.
53+
When True, the function can only be used within device code sections.
54+
_tiles_as_sizes: Whether tile indices should be converted to sizes automatically.
55+
Used primarily with tiling operations to transform indices to dimensions.
56+
_cache_type: Whether to cache the type information for repeated calls.
57+
_type_function: A callable that determines the return type of this function
58+
during type propagation phase.
59+
_codegen: A callable that generates the device code for this function.
60+
_fake_fn: A callable that provides a "fake" implementation used during
61+
tracing and compilation.
62+
_prepare_args: A callable that preprocesses the arguments before they're
63+
passed to the actual function implementation.
64+
_signature: The function signature for binding and validating arguments.
65+
"""
66+
4167
__qualname__: str
4268
_helion_api: Literal[True]
4369
# a device loop can transition between host and device code

helion/language/memory_ops.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from .._compiler.inductor_lowering import CodegenState
1616

17-
__all__ = ["load", "store"]
17+
__all__ = ["atomic_add", "load", "store"]
1818

1919

2020
@has_side_effect
@@ -53,6 +53,14 @@ def _(state: CodegenState) -> ast.AST:
5353

5454
@_decorators.api(tiles_as_sizes=True)
5555
def load(tensor: torch.Tensor, index: list[object]) -> torch.Tensor:
56+
"""Load a value from a tensor using a list of indices.
57+
58+
Args:
59+
tensor: The tensor to load from
60+
index: The indices to use to index into the tensor
61+
Returns:
62+
torch.Tensor: The loaded value
63+
"""
5664
raise exc.NotInsideKernel
5765

5866

@@ -70,3 +78,85 @@ def _(state: CodegenState) -> ast.AST:
7078
return state.device_function.indexing_strategy.codegen_load(
7179
state, tensor, [*subscript]
7280
)
81+
82+
83+
@has_side_effect
84+
@_decorators.api()
85+
def atomic_add(
86+
target: torch.Tensor,
87+
index: list[object],
88+
value: torch.Tensor | float,
89+
sem: str = "relaxed",
90+
) -> None:
91+
"""
92+
Atomically add a value to a target tensor.
93+
94+
Args:
95+
target: The tensor to add to
96+
index: Indices into target selecting where the value is accumulated
97+
value: The value to add
98+
sem: The memory ordering semantics (default: 'relaxed')
99+
100+
Returns:
101+
None
102+
"""
103+
raise exc.NotInsideKernel
104+
105+
106+
@_decorators.prepare_args(atomic_add)
107+
def _(
108+
target: torch.Tensor,
109+
index: list[object],
110+
value: torch.Tensor | float,
111+
sem: str = "relaxed",
112+
) -> tuple[torch.Tensor, object, torch.Tensor | float | int, str]:
113+
from helion._compiler.tile_index_proxy import TileIndexProxy
114+
115+
valid_sems = {"relaxed", "acquire", "release", "acq_rel"}
116+
if sem not in valid_sems:
117+
raise ValueError(
118+
f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
119+
)
120+
121+
index = TileIndexProxy.prepare_index(index)
122+
index = TileIndexProxy.tiles_to_sizes(index)
123+
124+
return (target, index, value, sem)
125+
126+
127+
@_decorators.register_fake(atomic_add)
128+
def _(
129+
target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
130+
) -> None:
131+
return None
132+
133+
134+
@_decorators.codegen(atomic_add)
135+
def _(state: CodegenState) -> ast.AST:
136+
import ast
137+
138+
from .._compiler.ast_extension import expr_from_string
139+
140+
target = state.proxy_arg(0)
141+
index = state.proxy_arg(1)
142+
value = state.proxy_arg(2)
143+
sem = expr_from_string(f"'{state.proxy_arg(3)}'")
144+
145+
assert isinstance(target, torch.Tensor)
146+
147+
indices = SubscriptIndexing.create(state, target, index)
148+
name = state.device_function.tensor_arg(target).name
149+
150+
value_expr = (
151+
state.ast_args[2]
152+
if isinstance(value, torch.Tensor)
153+
else ast.Constant(value=value)
154+
)
155+
156+
return expr_from_string(
157+
f"tl.atomic_add({name} + offset, value, mask=mask, sem=sem)",
158+
value=value_expr,
159+
offset=indices.index_expr,
160+
mask=indices.mask_expr,
161+
sem=sem,
162+
)

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,3 @@ exclude = [
8787

8888
[tool.hatch.metadata]
8989
allow-direct-references = true
90-

0 commit comments

Comments
 (0)