Skip to content

Commit 9858b0c

Browse files
authored
Add hl.grid(...) support (#59)
1 parent b2306ff commit 9858b0c

File tree

9 files changed

+583
-46
lines changed

9 files changed

+583
-46
lines changed

helion/_compiler/compile_environment.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,9 @@ def from_config_assert(self, config: Config) -> int | torch.SymInt:
311311
def is_flattened(self, config: Config) -> bool:
312312
return self.block_size_source.is_flattened(config)
313313

314+
def is_grid(self) -> bool:
    """Return whatever ``self.block_size_source.is_grid()`` reports.

    This object does not decide the answer itself; it forwards the query
    to its block-size source.
    """
    source = self.block_size_source
    return source.is_grid()
316+
314317
def get_order(self, config: Config, count: int) -> list[int]:
315318
return self.block_size_source.get_order(config, count)
316319

@@ -330,6 +333,9 @@ def from_config(self, config: Config) -> int | torch.SymInt | None:
330333
def is_flattened(self, config: Config) -> bool:
331334
return False
332335

336+
def is_grid(self) -> bool:
    """Return ``False``: the default answer for a block-size source.

    ``GridBlockSizeSource`` overrides this method to return ``True``.
    """
    return False
338+
333339
def get_order(self, config: Config, count: int) -> list[int]:
334340
return [*range(count)]
335341

@@ -348,6 +354,15 @@ def from_config(self, config: Config) -> int | torch.SymInt:
348354
return self.value
349355

350356

357+
@dataclasses.dataclass
class GridBlockSizeSource(BlockSizeSource):
    """Block-size source whose ``is_grid()`` query answers ``True``.

    ``GridIndexType.allocate`` passes an instance of this class as the
    ``source`` when it allocates a block size.
    """

    def from_config(self, config: Config) -> int:
        """Always raise ``NotImplementedError``.

        NOTE(review): presumably a grid block size is never looked up in a
        ``Config`` -- confirm against callers.
        """
        raise NotImplementedError()

    def is_grid(self) -> bool:
        """Return ``True`` to identify this source as a grid source."""
        return True
364+
365+
351366
@dataclasses.dataclass
352367
class LoopSpecBlockSizeSource(BlockSizeSource):
353368
loop_spec: int

helion/_compiler/device_ir.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from .source_location import current_location
4141
from .tile_index_proxy import CheckForIndexCalls
4242
from .tile_index_proxy import TileIndexProxy
43+
from .type_propagation import GridIndexType
4344
from .type_propagation import IterType
4445
from .type_propagation import SequenceType
4546
from .type_propagation import TensorType
@@ -464,7 +465,9 @@ def run_subgraph(*args: object) -> list[object]:
464465
iter_vars = inner_type.unpack()
465466
else:
466467
iter_vars = [inner_type]
467-
assert all(isinstance(x, TileIndexType) for x in iter_vars)
468+
assert all(
469+
isinstance(x, (TileIndexType, GridIndexType)) for x in iter_vars
470+
)
468471
graph_idx = self.device_ir.add_graph(
469472
graph,
470473
ForLoopGraphInfo,

helion/_compiler/indexing_strategy.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,13 @@ def compute_shape(
169169
if isinstance(symbol, sympy.Symbol):
170170
origin = HostFunction.current().symbol_to_origin.get(symbol.name)
171171
if origin and isinstance(origin.origin, BlockSizeOrigin):
172-
if tensor.size(tensor.ndim - len(input_size) - 1) != 1:
172+
if (
173+
CompileEnvironment.current()
174+
.block_sizes[origin.origin.block_size_idx]
175+
.is_grid()
176+
):
177+
pass
178+
elif tensor.size(tensor.ndim - len(input_size) - 1) != 1:
173179
output_size.append(k)
174180
else:
175181
output_size.append(1)
@@ -200,6 +206,7 @@ def create(
200206
mask_values = {}
201207
output_size = SubscriptIndexing.compute_shape(fake_value, index)
202208
dtype = CompileEnvironment.current().triton_index_type()
209+
first_non_grid_index = 0
203210
for n, k in enumerate(index):
204211
if k is None:
205212
output_idx += 1
@@ -210,8 +217,18 @@ def create(
210217
origin = None
211218
if isinstance(symbol, sympy.Symbol):
212219
origin = HostFunction.current().symbol_to_origin.get(symbol.name)
213-
expand = tile_strategy.expand_str(output_size, output_idx)
214220
if origin and isinstance(origin.origin, BlockSizeOrigin):
221+
if (
222+
CompileEnvironment.current()
223+
.block_sizes[origin.origin.block_size_idx]
224+
.is_grid()
225+
):
226+
first_non_grid_index = n + 1
227+
expand = tile_strategy.expand_str(output_size, output_idx)
228+
else:
229+
expand = tile_strategy.expand_str(
230+
output_size, output_idx - first_non_grid_index
231+
)
215232
index_var = state.codegen.index_var(origin.origin.block_size_idx)
216233
i = len(index_values)
217234
index_values.append(f"({index_var}){expand}")
@@ -221,10 +238,15 @@ def create(
221238
mask_values.setdefault(f"({mask}){expand}")
222239
output_idx += 1
223240
else:
241+
expand = tile_strategy.expand_str(
242+
output_size, output_idx - first_non_grid_index
243+
)
224244
val = state.device_function.literal_expr(k)
225245
index_values.append(f"tl.full([1], {val}, {dtype}){expand}")
226246
elif isinstance(k, slice) and str(k) == "slice(None, None, None)":
227-
expand = tile_strategy.expand_str(output_size, output_idx)
247+
expand = tile_strategy.expand_str(
248+
output_size, output_idx - first_non_grid_index
249+
)
228250
size = fake_value.size(len(index_values))
229251
if size != 1:
230252
env = CompileEnvironment.current()
@@ -238,21 +260,25 @@ def create(
238260
index_values.append(f"tl.zeros([1], {dtype}){expand}")
239261
output_idx += 1
240262
elif isinstance(k, torch.Tensor) and k.ndim == 1:
241-
expand = tile_strategy.expand_str(output_size, output_idx)
263+
expand = tile_strategy.expand_str(
264+
output_size, output_idx - first_non_grid_index
265+
)
242266
ast_index = state.ast_args[1]
243267
assert isinstance(ast_index, (list, tuple))
244268
assert len(ast_index) == len(index)
245269
index_var = state.codegen.lift(ast_index[n]).id
246270
index_values.append(f"({index_var}){expand}")
247271
if (
248-
block_idx := TileStrategy.get_block_index(output_size[output_idx])
272+
block_idx := TileStrategy.get_block_index(
273+
output_size[output_idx - first_non_grid_index]
274+
)
249275
) is not None:
250276
if mask := state.codegen.mask_var(block_idx):
251277
mask_values.setdefault(f"({mask}){expand}")
252278
output_idx += 1
253279
else:
254280
raise exc.InvalidIndexingType(k)
255-
assert len(output_size) == output_idx
281+
assert len(output_size) == output_idx - first_non_grid_index
256282
assert len(index_values) == fake_value.ndim
257283

258284
index_expr = []

helion/_compiler/tile_dispatch.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from helion._compiler.tile_strategy import DeviceGridState
1616
from helion._compiler.tile_strategy import DeviceLoopState
1717
from helion._compiler.tile_strategy import FlattenedTileStrategy
18+
from helion._compiler.tile_strategy import NDGridTileStrategy
1819
from helion._compiler.tile_strategy import NDTileStrategy
1920
from helion._compiler.tile_strategy import TileStrategy
2021

@@ -60,7 +61,13 @@ def _add_loop_strategy(
6061
env = CompileEnvironment.current()
6162
block_size_infos = [env.block_sizes[i] for i in block_indices]
6263
loop_order = block_size_infos[0].get_order(config, len(block_size_infos))
63-
if block_size_infos[0].is_flattened(config):
64+
if block_size_infos[0].is_grid():
65+
strategy: TileStrategy = NDGridTileStrategy(
66+
fn,
67+
block_indices,
68+
loop_order=loop_order,
69+
)
70+
elif block_size_infos[0].is_flattened(config):
6471
strategy: TileStrategy = FlattenedTileStrategy(
6572
fn,
6673
block_indices,

helion/_compiler/tile_strategy.py

Lines changed: 74 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -346,9 +346,7 @@ def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]:
346346
return output
347347

348348

349-
class NDTileStrategy(BlockSizeTileStrategy):
350-
"""Do up to 3D tiling using the kernel grid."""
351-
349+
class _BaseNDTileStrategy(BlockSizeTileStrategy):
352350
block_size: list[SymIntLike]
353351

354352
def __init__(
@@ -357,21 +355,15 @@ def __init__(
357355
block_indices: list[int],
358356
block_size: list[SymIntLike] | SymIntLike,
359357
loop_order: list[int],
360-
l2_grouping: int,
361358
) -> None:
362359
assert isinstance(block_size, list)
363360
super().__init__(fn, block_indices, block_size, loop_order)
364-
self.mask_vars: dict[int, str | None] = {}
365-
self.l2_grouping = l2_grouping
366361
for bs, block_idx in zip(block_size, block_indices, strict=True):
367362
if (block_idx,) not in fn.block_size_var_cache and bs != 1:
368363
fn.block_size_var_cache[(block_idx,)] = fn.new_var(
369364
f"_BLOCK_SIZE_{block_idx}"
370365
)
371366

372-
def mask_var(self, block_idx: int) -> str | None:
373-
return self.mask_vars[block_idx]
374-
375367
def codegen_grid(self, state: CodegenState) -> None:
376368
block_indices = self.block_indices
377369
env = CompileEnvironment.current()
@@ -408,34 +400,16 @@ def codegen_grid(self, state: CodegenState) -> None:
408400
state.add_statement(
409401
f"{index_var} = {offset_var} + tl.zeros([1], {dtype})"
410402
)
411-
mask_statement = self._setup_mask(state, block_idx, block_size, index_var)
412-
if mask_statement is not None:
413-
state.add_statement(mask_statement)
403+
if hasattr(self, "_setup_mask"):
404+
mask_statement = self._setup_mask( # pyre-ignore[16]
405+
state, block_idx, block_size, index_var
406+
)
407+
if mask_statement is not None:
408+
state.add_statement(mask_statement)
414409
pids.append(ProgramID(pid_var, block_size_var, numel))
415410
pids.codegen(state)
416411

417-
def _setup_mask(
418-
self,
419-
state: CodegenState,
420-
block_idx: int,
421-
block_size: SymIntLike,
422-
index_var: str,
423-
) -> ast.stmt | None:
424-
env = CompileEnvironment.current()
425-
numel = env.block_sizes[block_idx].numel
426-
if block_size == 1 or env.known_multiple(numel, block_size):
427-
self.mask_vars[block_idx] = None
428-
return None
429-
self.mask_vars[block_idx] = mask_var = self.fn.new_var(
430-
f"mask_{block_idx}", dce=True
431-
)
432-
return statement_from_string(
433-
f"{mask_var} = ({index_var} < ({state.device_function.sympy_expr(numel)}))"
434-
)
435-
436412
def select_pid_strategy(self) -> ProgramIDs:
437-
if self.l2_grouping > 1:
438-
return L2GroupingProgramIDs(group_size=self.l2_grouping)
439413
if 1 < len(self.block_indices) <= 3 and self.fn.config.use_yz_grid:
440414
return GridProgramIDs()
441415
return VirtualProgramIDs()
@@ -483,9 +457,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
483457
f"{index_var} = {offset_var} + tl.arange(0, ({block_size_var})).to({dtype})"
484458
),
485459
]
486-
mask_statement = self._setup_mask(state, block_idx, block_size, index_var)
487-
if mask_statement is not None:
488-
extra_body.append(mask_statement)
460+
if hasattr(self, "_setup_mask"):
461+
mask_statement = self._setup_mask( # pyre-ignore[16]
462+
state, block_idx, block_size, index_var
463+
)
464+
if mask_statement is not None:
465+
extra_body.append(mask_statement)
489466
body[:] = [*extra_body, *body]
490467
body = [for_node]
491468
assert for_node is not None
@@ -500,6 +477,67 @@ def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]:
500477
return shapes
501478

502479

480+
class NDTileStrategy(_BaseNDTileStrategy):
    """Do up to 3D tiling using the kernel grid.

    Adds per-block mask bookkeeping and an optional L2-grouping program-id
    strategy to what ``_BaseNDTileStrategy`` provides.
    """

    def __init__(
        self,
        fn: DeviceFunction,
        block_indices: list[int],
        block_size: list[SymIntLike] | SymIntLike,
        loop_order: list[int],
        l2_grouping: int,
    ) -> None:
        """Initialise the base strategy, then the mask table and L2 group size."""
        super().__init__(
            fn=fn,
            block_indices=block_indices,
            block_size=block_size,
            loop_order=loop_order,
        )
        # Maps block index -> mask variable name; None means "no mask needed".
        # Entries are written by ``_setup_mask``.
        self.mask_vars: dict[int, str | None] = {}
        self.l2_grouping = l2_grouping

    def mask_var(self, block_idx: int) -> str | None:
        """Return the mask variable recorded for ``block_idx``.

        Raises ``KeyError`` if ``_setup_mask`` has not recorded an entry for
        that block index.
        """
        return self.mask_vars[block_idx]

    def _setup_mask(
        self,
        state: CodegenState,
        block_idx: int,
        block_size: SymIntLike,
        index_var: str,
    ) -> ast.stmt | None:
        """Record the mask for ``block_idx`` and build the statement defining it.

        Returns ``None``, and records ``None`` in ``self.mask_vars``, when
        ``block_size`` is 1 or the block's ``numel`` is a known multiple of
        ``block_size``. Otherwise allocates a new ``mask_<block_idx>``
        variable and returns the statement
        ``<mask> = (<index_var> < (<numel>))``.
        """
        env = CompileEnvironment.current()
        numel = env.block_sizes[block_idx].numel
        mask_unneeded = block_size == 1 or env.known_multiple(numel, block_size)
        if mask_unneeded:
            self.mask_vars[block_idx] = None
            return None
        mask_name = self.fn.new_var(f"mask_{block_idx}", dce=True)
        self.mask_vars[block_idx] = mask_name
        upper_bound = state.device_function.sympy_expr(numel)
        return statement_from_string(
            f"{mask_name} = ({index_var} < ({upper_bound}))"
        )

    def select_pid_strategy(self) -> ProgramIDs:
        """Choose L2 grouping when ``l2_grouping > 1``, else the base choice."""
        return (
            L2GroupingProgramIDs(group_size=self.l2_grouping)
            if self.l2_grouping > 1
            else super().select_pid_strategy()
        )
521+
522+
523+
class NDGridTileStrategy(_BaseNDTileStrategy):
    """N-dimensional strategy whose block size is 1 along every block index.

    It defines no ``_setup_mask`` and reports no mask variable for any block.
    """

    def __init__(
        self,
        fn: DeviceFunction,
        block_indices: list[int],
        loop_order: list[int],
    ) -> None:
        """Initialise the base strategy with a block size of 1 per block index."""
        super().__init__(
            fn=fn,
            block_indices=block_indices,
            block_size=[1 for _ in block_indices],  # pyre-ignore[6]
            loop_order=loop_order,
        )

    def mask_var(self, block_idx: int) -> str | None:
        """Return ``None`` for every ``block_idx``: this strategy has no masks."""
        return None
539+
540+
503541
class CompactedShape(NamedTuple):
504542
size_str: str
505543
user_indices: list[int]

helion/_compiler/type_propagation.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,39 @@ def merge(self, other: TypeInfo) -> TypeInfo:
976976
return super().merge(other)
977977

978978

979+
class GridIndexType(SymIntType):
    """Symbolic-integer type tied to one entry of ``CompileEnvironment.block_sizes``.

    The wrapped value is the ``var`` of the block-size entry at
    ``block_size_idx``.
    """

    # Index into ``CompileEnvironment.current().block_sizes``.
    block_size_idx: int

    def __init__(self, origin: Origin, block_size_idx: int) -> None:
        """Bind this type to the block-size entry at ``block_size_idx``."""
        from .._compiler.compile_environment import CompileEnvironment

        block_info = CompileEnvironment.current().block_sizes[block_size_idx]
        super().__init__(origin, block_info.var)
        self.block_size_idx = block_size_idx

    def __str__(self) -> str:  # pragma: no cover – debug helper
        """Render as ``<ClassName>(<block_size_idx>)``."""
        cls_name = type(self).__name__
        return f"{cls_name}({self.block_size_idx})"

    @staticmethod
    def allocate(numel: int | torch.SymInt, origin: Origin) -> GridIndexType:
        """Allocate a block size for ``numel`` and wrap its index.

        The block size is registered with a ``GridBlockSizeSource`` as its
        source, so ``is_grid()`` is ``True`` for it.
        """
        from .._compiler.compile_environment import CompileEnvironment
        from .._compiler.compile_environment import GridBlockSizeSource

        new_idx = CompileEnvironment.current().allocate_block_size(
            numel, source=GridBlockSizeSource()
        )
        return GridIndexType(origin, new_idx)

    def merge(self, other: TypeInfo) -> TypeInfo:  # type: ignore[override]
        """Merge with ``other`` where control flow joins.

        Two ``GridIndexType`` values merge to ``self`` when they share a
        block-size index and to ``UnknownType`` when they do not. Any other
        type is handled by the parent class.
        """
        if not isinstance(other, GridIndexType):
            return super().merge(other)
        if other.block_size_idx == self.block_size_idx:
            return self
        return UnknownType(
            debug_msg=f"GridIndexType mismatch in control flow: {self.block_size_idx} vs {other.block_size_idx}",
            origin=other.origin,
        )
1010+
1011+
9791012
class IterType(TypeInfo):
9801013
inner: TypeInfo
9811014

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .constexpr import ConstExpr as constexpr # noqa: F401
44
from .creation_ops import full as full
55
from .creation_ops import zeros as zeros
6+
from .loops import grid as grid
67
from .loops import register_block_size as register_block_size
78
from .loops import tile as tile
89
from .memory_ops import load as load

0 commit comments

Comments
 (0)