|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import TYPE_CHECKING |
| 4 | + |
| 5 | +import torch |
| 6 | +from torch.fx import has_side_effect |
| 7 | + |
| 8 | +from .. import exc |
| 9 | +from . import _decorators |
| 10 | + |
| 11 | +if TYPE_CHECKING: |
| 12 | + import ast |
| 13 | + |
| 14 | + from .._compiler.inductor_lowering import CodegenState |
| 15 | + |
| 16 | + |
| 17 | +@has_side_effect |
| 18 | +@_decorators.api(tiles_as_sizes=True) |
| 19 | +def wait( |
| 20 | + signal_pad: torch.Tensor, |
| 21 | + index: list[object], |
| 22 | + signal: int = 1, |
| 23 | + update: int | None = None, |
| 24 | + op: str = "ld", |
| 25 | + sem: str = "acquire", |
| 26 | + scope: str = "sys", |
| 27 | +) -> None: |
| 28 | + """ |
| 29 | + Wait for a signal before accessing the data tensor. |
| 30 | + Args: |
| 31 | + signal_pad: The signal tensor to wait on |
| 32 | + index: Indices into signal_pad tensor for which signal to wait for |
| 33 | + signal: the signal to wait for |
| 34 | + update: update the signal_pad after acquiring the signal. |
| 35 | + sem: The memory op for acquring the lock (default: 'ld.acquire') |
| 36 | +
|
| 37 | + Returns: |
| 38 | + None |
| 39 | + """ |
| 40 | + raise exc.NotInsideKernel |
| 41 | + |
| 42 | + |
| 43 | +@_decorators.type_propagation(wait) |
| 44 | +def _(*args: object, origin: object, **kwargs: object) -> object: |
| 45 | + from .._compiler.type_propagation import NoType |
| 46 | + |
| 47 | + return NoType(origin=origin) |
| 48 | + |
| 49 | + |
| 50 | +@_decorators.prepare_args(wait) |
| 51 | +def _( |
| 52 | + signal_pad: torch.Tensor, |
| 53 | + index: list[object], |
| 54 | + signal: int = 1, |
| 55 | + update: int | None = None, |
| 56 | + op: str = "ld", |
| 57 | + sem: str = "acquire", |
| 58 | + scope: str = "sys", |
| 59 | +) -> tuple[torch.Tensor, object, int, int | None, str, str]: |
| 60 | + from helion._compiler.tile_index_proxy import TileIndexProxy |
| 61 | + |
| 62 | + print("in wait prepare_args") |
| 63 | + |
| 64 | + valid_ops = {"ld", "atomic_add", "atomic_cas"} |
| 65 | + valid_sems = {"relaxed", "acquire", "acq_rel"} |
| 66 | + valid_scopes = {"sys", "gpu"} |
| 67 | + |
| 68 | + if op not in valid_ops: |
| 69 | + raise ValueError(f"Invalid Wait op '{op}'. Must be one of {valid_ops}. ") |
| 70 | + |
| 71 | + if sem == "release": |
| 72 | + raise ValueError( |
| 73 | + f"Do not use '{sem}' for wait patterns. Wait sem must be one of {valid_sems}." |
| 74 | + ) |
| 75 | + |
| 76 | + if sem not in valid_sems: |
| 77 | + raise ValueError( |
| 78 | + f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}." |
| 79 | + ) |
| 80 | + |
| 81 | + if op == "atomic_cas" and not update: |
| 82 | + raise ValueError( |
| 83 | + f"{op} without an update value. Do you want to use 'ld' instead? " |
| 84 | + ) |
| 85 | + |
| 86 | + if op == "ld": |
| 87 | + assert update is None |
| 88 | + update = 0 |
| 89 | + |
| 90 | + if scope not in valid_scopes: |
| 91 | + raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.") |
| 92 | + |
| 93 | + print("index:", index) |
| 94 | + |
| 95 | + index = TileIndexProxy.prepare_index(index) |
| 96 | + print("index prepare_index", index) |
| 97 | + index = TileIndexProxy.tiles_to_sizes(index) |
| 98 | + |
| 99 | + print("tiles_to_sizes", index) |
| 100 | + |
| 101 | + print("tiles_to_sizes", type(index[0])) |
| 102 | + |
| 103 | + return (signal_pad, index, signal, update, op, sem) |
| 104 | + |
| 105 | + |
| 106 | +@_decorators.register_fake(wait) |
| 107 | +def _( |
| 108 | + signal_pad: torch.Tensor, |
| 109 | + index: list[object], |
| 110 | + signal: int = 1, |
| 111 | + update: int | None = None, |
| 112 | + op: str = "ld", |
| 113 | + sem: str = "acquire", |
| 114 | + scope: str = "sys", |
| 115 | +) -> None: |
| 116 | + return None |
| 117 | + |
| 118 | + |
| 119 | +def get_lock_spin_ptx(name: str, op: str, sem: str, scope: str): |
| 120 | + ptx_acquire_list = { |
| 121 | + "ld": f"ld.global.{sem}.{scope}.u32 $0, [$1];", |
| 122 | + "atomic_cas": f"atom.global.{sem}.{scope}.cas.b32 $0, [$1], $2, $3;", |
| 123 | + "atomic_add": f"atom.global.{sem}.{scope}.add.u32 $0, [$1], $2;", |
| 124 | + } |
| 125 | + |
| 126 | + acquire_lock_expr = ptx_acquire_list[op] |
| 127 | + ptx_template = f''' |
| 128 | +tl.inline_asm_elementwise(""" |
| 129 | + {{ |
| 130 | + .reg .u32 %tmp32_<1>; |
| 131 | + .reg .pred %p<2>; |
| 132 | +
|
| 133 | + // calculate tid assuming tid.y=tid.z=1. TODO: get this from Triton |
| 134 | + mov.u32 %tmp32_0, %tid.x; |
| 135 | + setp.eq.s32 %p1, %tmp32_0, 0; |
| 136 | +
|
| 137 | + mov.u32 $0, 0; |
| 138 | + // initialize tmp_0 to 0 |
| 139 | + wait_block: |
| 140 | + @%p1 {acquire_lock_expr} |
| 141 | + setp.ne.u32 %p0, $0, $2; |
| 142 | + and.pred %p0, %p0, %p1; |
| 143 | + @%p0 bra wait_block; |
| 144 | + bar.sync 0; |
| 145 | + }} |
| 146 | + """, |
| 147 | + "=r, l, r, r", |
| 148 | + [{name} + offset, signal, update], |
| 149 | + dtype={name}.dtype.element_ty, |
| 150 | + is_pure=False, |
| 151 | + pack=1, |
| 152 | +) |
| 153 | +''' |
| 154 | + print("ptx_template", ptx_template) |
| 155 | + return ptx_template |
| 156 | + |
| 157 | + |
| 158 | +@_decorators.codegen(wait) |
| 159 | +def _(state: CodegenState) -> ast.AST: |
| 160 | + import ast |
| 161 | + |
| 162 | + from .._compiler.ast_extension import expr_from_string |
| 163 | + from .._compiler.indexing_strategy import SubscriptIndexing |
| 164 | + |
| 165 | + signal_pad = state.proxy_arg(0) |
| 166 | + index = state.proxy_arg(1) |
| 167 | + signal = state.proxy_arg(2) |
| 168 | + update = state.proxy_arg(3) |
| 169 | + op = state.proxy_arg(4) |
| 170 | + sem = state.proxy_arg(5) |
| 171 | + scope = state.proxy_arg(6) |
| 172 | + |
| 173 | + assert isinstance(signal_pad, torch.Tensor) |
| 174 | + assert isinstance(index, (list)) |
| 175 | + |
| 176 | + print(index, "index") |
| 177 | + indices = SubscriptIndexing.create(state, signal_pad, index) |
| 178 | + print("indices", indices) |
| 179 | + signal_pad_name = state.device_function.tensor_arg(signal_pad).name |
| 180 | + |
| 181 | + signal_expr = ast.Constant(value=signal) |
| 182 | + update_expr = ast.Constant(value=update) |
| 183 | + |
| 184 | + lock_spin_ptx = get_lock_spin_ptx(signal_pad_name, op, sem, scope) |
| 185 | + |
| 186 | + return expr_from_string( |
| 187 | + lock_spin_ptx, |
| 188 | + offset=indices.index_expr, |
| 189 | + signal=signal_expr, |
| 190 | + update=update_expr, |
| 191 | + ) |
| 192 | + |
| 193 | + |
| 194 | +@has_side_effect |
| 195 | +@_decorators.api(tiles_as_sizes=True) |
| 196 | +def signal( |
| 197 | + signal_pad: torch.Tensor, |
| 198 | + index: list[object], |
| 199 | + signal: int = 1, |
| 200 | + sem: str = "ld.acquire", |
| 201 | +) -> None: |
| 202 | + """ |
| 203 | + Wait for a signal before accessing the data tensor. |
| 204 | + Args: |
| 205 | + signal_pad: The signal tensor to wait on |
| 206 | + index: Indices into signal_pad tensor for which signal to wait for |
| 207 | + signal: the signal to wait for |
| 208 | + update: update the signal_pad after acquiring the signal. |
| 209 | + sem: The memory op for acquring the lock (default: 'ld.acquire') |
| 210 | +
|
| 211 | + Returns: |
| 212 | + None |
| 213 | + """ |
| 214 | + raise exc.NotInsideKernel |
0 commit comments