
Commit ae27afc

Add hl.signal & hl.wait (ptx impl).
stack-info: PR: #189, branch: joydddd/stack/5
1 parent b819dad commit ae27afc

File tree

helion/language/__init__.py
helion/language/signal_wait.py
test/test_signal_wait.py

3 files changed: +325 −0 lines changed

helion/language/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@
 from .memory_ops import atomic_add as atomic_add
 from .memory_ops import load as load
 from .memory_ops import store as store
+from .signal_wait import signal as signal
+from .signal_wait import wait as wait
 from .tile_ops import tile_begin as tile_begin
 from .tile_ops import tile_block_size as tile_block_size
 from .tile_ops import tile_end as tile_end

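With these exports, `wait` (and the `signal` stub below, which has no codegen yet in this commit) become available on the `hl` namespace alongside `load`, `store`, and `atomic_add`. A minimal call-site sketch, mirroring the test further down; `progress`, `tile_m`, and `tile_n` are assumed to come from the surrounding kernel and are not defined here:

    import helion.language as hl

    # Inside a Helion kernel body (sketch): block until a producer has marked
    # tile (tile_m, tile_n) ready in `progress`, then it is safe to read the data.
    hl.wait(
        progress,                        # signal pad tensor (one slot per tile)
        [tile_m.begin, tile_n.begin],    # which slot to poll
        signal=1,                        # value the producer writes when done
        op="ld",
        sem="acquire",
        scope="gpu",
    )
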
helion/language/signal_wait.py

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch.fx import has_side_effect

from .. import exc
from . import _decorators

if TYPE_CHECKING:
    import ast

    from .._compiler.inductor_lowering import CodegenState


@has_side_effect
@_decorators.api(tiles_as_sizes=True)
def wait(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    update: int | None = None,
    op: str = "ld",
    sem: str = "acquire",
    scope: str = "sys",
) -> None:
    """
    Wait for a signal before accessing the data tensor.

    Args:
        signal_pad: The signal tensor to wait on
        index: Indices into the signal_pad tensor selecting the signal to wait for
        signal: the signal value to wait for
        update: value to write back to signal_pad after acquiring the signal
        op: the memory op used to read/acquire the signal (default: 'ld')
        sem: the memory semantic for acquiring the signal (default: 'acquire')
        scope: the scope of the memory operation (default: 'sys')

    Returns:
        None
    """
    raise exc.NotInsideKernel


@_decorators.type_propagation(wait)
def _(*args: object, origin: object, **kwargs: object) -> object:
    from .._compiler.type_propagation import NoType

    return NoType(origin=origin)


@_decorators.prepare_args(wait)
def _(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    update: int | None = None,
    op: str = "ld",
    sem: str = "acquire",
    scope: str = "sys",
) -> tuple[torch.Tensor, object, int, int | None, str, str]:
    from helion._compiler.tile_index_proxy import TileIndexProxy

    valid_ops = {"ld", "atomic_add", "atomic_cas"}
    valid_sems = {"relaxed", "acquire", "acq_rel"}
    valid_scopes = {"sys", "gpu"}

    if op not in valid_ops:
        raise ValueError(f"Invalid wait op '{op}'. Must be one of {valid_ops}.")

    if sem == "release":
        raise ValueError(
            f"Do not use '{sem}' for wait patterns. Wait sem must be one of {valid_sems}."
        )

    if sem not in valid_sems:
        raise ValueError(
            f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
        )

    if op == "atomic_cas" and not update:
        raise ValueError(
            f"{op} without an update value. Do you want to use 'ld' instead?"
        )

    if op == "ld":
        assert update is None
        update = 0

    if scope not in valid_scopes:
        raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")

    index = TileIndexProxy.prepare_index(index)
    index = TileIndexProxy.tiles_to_sizes(index)

    return (signal_pad, index, signal, update, op, sem)


@_decorators.register_fake(wait)
def _(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    update: int | None = None,
    op: str = "ld",
    sem: str = "acquire",
    scope: str = "sys",
) -> None:
    return None


def get_lock_spin_ptx(name: str, op: str, sem: str, scope: str) -> str:
    ptx_acquire_list = {
        "ld": f"ld.global.{sem}.{scope}.u32 $0, [$1];",
        "atomic_cas": f"atom.global.{sem}.{scope}.cas.b32 $0, [$1], $2, $3;",
        "atomic_add": f"atom.global.{sem}.{scope}.add.u32 $0, [$1], $2;",
    }

    acquire_lock_expr = ptx_acquire_list[op]
    ptx_template = f'''
    tl.inline_asm_elementwise("""
    {{
        .reg .u32 %tmp32_<1>;
        .reg .pred %p<2>;

        // calculate tid assuming tid.y=tid.z=1. TODO: get this from Triton
        mov.u32 %tmp32_0, %tid.x;
        setp.eq.s32 %p1, %tmp32_0, 0;

        mov.u32 $0, 0;  // initialize the output register $0 to 0
    wait_block:
        @%p1 {acquire_lock_expr}
        setp.ne.u32 %p0, $0, $2;
        and.pred %p0, %p0, %p1;
        @%p0 bra wait_block;
        bar.sync 0;
    }}
    """,
    "=r, l, r, r",
    [{name} + offset, signal, update],
    dtype={name}.dtype.element_ty,
    is_pure=False,
    pack=1,
    )
    '''
    return ptx_template


@_decorators.codegen(wait)
def _(state: CodegenState) -> ast.AST:
    import ast

    from .._compiler.ast_extension import expr_from_string
    from .._compiler.indexing_strategy import SubscriptIndexing

    signal_pad = state.proxy_arg(0)
    index = state.proxy_arg(1)
    signal = state.proxy_arg(2)
    update = state.proxy_arg(3)
    op = state.proxy_arg(4)
    sem = state.proxy_arg(5)
    scope = state.proxy_arg(6)

    assert isinstance(signal_pad, torch.Tensor)
    assert isinstance(index, list)

    indices = SubscriptIndexing.create(state, signal_pad, index)
    signal_pad_name = state.device_function.tensor_arg(signal_pad).name

    signal_expr = ast.Constant(value=signal)
    update_expr = ast.Constant(value=update)

    lock_spin_ptx = get_lock_spin_ptx(signal_pad_name, op, sem, scope)

    return expr_from_string(
        lock_spin_ptx,
        offset=indices.index_expr,
        signal=signal_expr,
        update=update_expr,
    )


@has_side_effect
@_decorators.api(tiles_as_sizes=True)
def signal(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    sem: str = "ld.acquire",
) -> None:
    """
    Set a signal in the signal pad tensor.

    Args:
        signal_pad: The signal tensor to write to
        index: Indices into the signal_pad tensor selecting the signal to set
        signal: the signal value to set
        sem: the memory semantic for setting the signal (default: 'ld.acquire')

    Returns:
        None
    """
    raise exc.NotInsideKernel

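The generated PTX implements a block-wide spin: only thread 0 polls the signal slot (the `wait_block` loop), and `bar.sync 0` then releases the whole thread block together. A rough host-side reference model of those semantics (a sketch only; `wait_reference` is hypothetical and not part of this commit):

    import torch


    def wait_reference(signal_pad: torch.Tensor, offset: int, signal: int) -> None:
        # Thread 0 of the block spins on an acquire load of signal_pad[offset]
        # until it observes `signal` (the PTX wait_block loop)...
        while int(signal_pad[offset]) != signal:
            pass
        # ...then `bar.sync 0` releases every thread in the block at once.
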
test/test_signal_wait.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
from __future__ import annotations

import torch

import helion
from helion._testing import DEVICE
import helion.language as hl


@helion.kernel(
    static_shapes=True,
    config=helion.Config(
        block_sizes=[64, 64], num_warps=8, num_stages=4, indexing="block_ptr"
    ),
)
def wait_and_copy_kernel(x: torch.Tensor, progress: torch.Tensor) -> torch.Tensor:
    """Test spinning on a global-memory signal pad."""
    # TODO: call a proper API to auto-generate the signal pad based on tile size & tensor shape/stride.
    m, n = x.size()
    # block_m = hl.register_block_size(m)
    # block_n = hl.register_block_size(n)

    # tiles_m = (m + block_m - 1) // block_m  # cdiv
    # tiles_n = (n + block_n - 1) // block_n  # cdiv

    progress = progress.view(-1, 128)

    out = torch.empty_like(x)
    for tile_m, tile_n in hl.tile([m, n]):
        # index_m, index_n = hl.get_tile_index([tile_m, tile_n])
        hl.wait(
            progress,
            [tile_m.begin, tile_n.begin],
            signal=1,
            update=None,
            op="ld",
            scope="gpu",
            sem="acquire",
        )
        out[tile_m, tile_n] = x[tile_m, tile_n]

    return out


@helion.kernel(static_shapes=True)
def atomic_add_kernel(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty_like(x)
    for tile_m, tile_n in hl.tile([m, n]):
        out[tile_m, tile_n] = x[tile_m, tile_n]
        hl.atomic_add(out, [tile_m, tile_n], 1)
    return out


def test_tile_id():
    @helion.kernel(
        static_shapes=True,
        config=helion.Config(
            block_sizes=[
                16,
            ],
            num_warps=8,
            num_stages=4,
            indexing="block_ptr",
        ),
    )
    def test_tile_id_access(x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros_like(x, dtype=torch.int32)
        for tile in hl.tile(x.size(0)):
            out[tile] = tile.id
        return out

    x = torch.randn([64], device=DEVICE)
    result = test_tile_id_access(x)
    print(result)


def test_tile_id_indexing():
    @helion.kernel(
        static_shapes=True,
        config=helion.Config(
            block_sizes=[16, 16],
        ),
    )
    def test_tile_id_atomic_add(x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros_like(x, dtype=torch.int32)
        for tile_m, tile_n in hl.tile(x.size()):
            hl.atomic_add(out, [tile_m.id, tile_n.id], 1)
        return out

    x = torch.randn([64, 64], device=DEVICE)
    result = test_tile_id_atomic_add(x)
    print(result)


if __name__ == "__main__":
    # test_tile_id()
    test_tile_id_indexing()
    # m = 4096
    # n = 16384
    # x = torch.randn([m, n], device="cuda", dtype=torch.float32)
    # progress = torch.zeros(4096, device="cuda", dtype=torch.int32)
    # wait_and_copy_kernel(x, progress)

    # atomic_add_kernel(x)

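The commented-out `__main__` block hints at how `wait_and_copy_kernel` could be driven by hand. One way to smoke-test it without a producer kernel (a sketch, not part of this commit) is to pre-fill the progress pad with the expected signal value so `hl.wait` returns immediately and the kernel reduces to a tiled copy:

    import torch


    def run_wait_and_copy_smoke_test() -> None:
        m, n = 4096, 16384
        x = torch.randn([m, n], device="cuda", dtype=torch.float32)
        # One int32 slot per tile; setting every slot to 1 satisfies signal=1 up front.
        progress = torch.ones(4096, device="cuda", dtype=torch.int32)
        out = wait_and_copy_kernel(x, progress)
        torch.testing.assert_close(out, x)
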
0 commit comments