Commit e3c4590 (1 parent: 296d12d)

Add support for hl.tile(begin, end) and hl.tile(begin, end, block_size)

4 files changed: +276 additions, -92 deletions

helion/_compiler/compile_environment.py (10 additions, 3 deletions)

```diff
@@ -295,7 +295,8 @@ class NoCurrentEnvironment(RuntimeError):
     pass
 
 
-class BlockSizeInfo(typing.NamedTuple):
+@dataclasses.dataclass
+class BlockSizeInfo:
     """
     Information about a block size.
     Used to track the block size for a given dimension.
@@ -320,8 +321,14 @@ def known_multiple(self, block_size: int | torch.SymInt) -> bool:
         return CompileEnvironment.current().known_multiple(self.numel, block_size)
 
     def size_hint(self) -> int:
-        assert self.size is not None
-        return CompileEnvironment.current().size_hint(self.size)
+        size = self.size
+        assert size is not None
+        return CompileEnvironment.current().size_hint(size)
+
+    def mark_alternate_size(self, size: torch.SymInt | int | None) -> None:
+        """If a block size is used with a different size, we need to clear the hint to enable masking."""
+        if size is None or self.size is None or self.size != size:
+            self.size = None
 
     def symbol(self) -> sympy.Symbol:
         return self.var._sympy_()
```
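The switch from `typing.NamedTuple` to `@dataclasses.dataclass` is what makes the new method possible: a NamedTuple is immutable, while `mark_alternate_size` mutates `self.size` in place. A minimal standalone sketch of the intended semantics, with plain `int`s standing in for `torch.SymInt` (the name `BlockSizeInfoSketch` is hypothetical):

```python
import dataclasses


@dataclasses.dataclass
class BlockSizeInfoSketch:
    """Simplified stand-in for BlockSizeInfo (the real class also tracks var/numel)."""

    size: int | None

    def mark_alternate_size(self, size: int | None) -> None:
        # Reusing one block size with a different (or unknown) extent means the
        # size hint is no longer trustworthy; clear it so codegen adds masking.
        if size is None or self.size is None or self.size != size:
            self.size = None


info = BlockSizeInfoSketch(size=1000)
info.mark_alternate_size(1000)  # same extent: hint survives
assert info.size == 1000
info.mark_alternate_size(512)   # different extent: hint cleared
assert info.size is None
```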

helion/_compiler/type_propagation.py (1 addition, 1 deletion)

```diff
@@ -979,7 +979,7 @@ def allocate(
 
     @staticmethod
     def allocate_fixed(
-        numel: int | torch.SymInt, block_size: int | torch.SymInt, origin: Origin
+        numel: int | torch.SymInt | None, block_size: int | torch.SymInt, origin: Origin
     ) -> TileIndexType:
         env = CompileEnvironment.current()
         return TileIndexType(
```
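The widened `numel: ... | None` exists because the new `tile()` type propagation (in loops.py below) collapses a data-dependent extent to `None` before calling `allocate_fixed`. A standalone sketch mirroring that caller-side logic (the helper name is hypothetical):

```python
import torch


def collapse_data_dependent(begin_part, end_part):
    """Mirror of the size computation in tile()'s type propagation: a
    tensor-valued extent is data dependent, so numel becomes None."""
    size = end_part - begin_part
    if isinstance(size, torch.Tensor):
        return None  # allocate_fixed now accepts numel=None for this case
    return size


assert collapse_data_dependent(0, 1000) == 1000
assert collapse_data_dependent(0, torch.tensor(7)) is None
```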

helion/language/loops.py (137 additions, 88 deletions)

```diff
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING
 from typing import Iterator
 from typing import Sequence
+from typing import TypeGuard
 from typing import overload
 
 import torch
@@ -12,14 +13,12 @@
 from .._compiler.ast_extension import ExtendedAST
 from .._compiler.ast_extension import LoopType
 from .._compiler.ast_extension import expr_from_string
+from .._compiler.compile_environment import CompileEnvironment
 from .._compiler.tile_index_proxy import TileIndexProxy
 from .._compiler.type_propagation import GridIndexType
 from .._compiler.type_propagation import IterType
-from .._compiler.type_propagation import LiteralType
 from .._compiler.type_propagation import Origin
 from .._compiler.type_propagation import SequenceType
-from .._compiler.type_propagation import SymIntType
-from .._compiler.type_propagation import TensorType
 from .._compiler.type_propagation import TileIndexType
 from .._compiler.type_propagation import TypeInfo
 from .._compiler.type_propagation import UnknownType
@@ -40,23 +39,32 @@
 @_decorators.api(
     is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
 )
-def tile(sizes: int, /, block_size: object = None) -> Iterator[TileOutput]: ...
+def tile(
+    begin_or_end: int,
+    end_or_none: int | None = None,
+    /,
+    block_size: object = None,
+) -> Iterator[TileOutput]: ...
 
 
 @overload
 @_decorators.api(
     is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
 )
 def tile(
-    sizes: Sequence[int], /, block_size: object = None
+    begin_or_end: Sequence[int],
+    end_or_none: Sequence[int] | None = None,
+    /,
+    block_size: object = None,
 ) -> Iterator[Sequence[TileOutput]]: ...
 
 
 @_decorators.api(
     is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
 )
 def tile(
-    sizes: int | Sequence[int],
+    begin_or_end: int | Sequence[int],
+    end_or_none: int | Sequence[int] | None = None,
     /,
     block_size: object = None,
 ) -> Iterator[TileOutput] | Iterator[Sequence[TileOutput]]:
@@ -73,6 +81,16 @@ def tile(
     If used at the top level of a function, this becomes the grid of the kernel.
     Otherwise, it becomes a loop in the output kernel.
 
+    Similar to `range()` there are multiple forms of this function:
+        tile(end) iterates from 0 to `end - 1`, with autotuned block_size.
+        tile(begin, end) iterates from `begin` to `end - 1`, with autotuned block_size.
+        tile(begin, end, block_size) iterates from `begin` to `end - 1`, with the given block_size.
+        tile(end, block_size=block_size) iterates from 0 to `end - 1`, with the given block_size.
+
+    begin/end/block_size can be a single integer or a sequence of integers to specify
+    multidimensional iteration. Block sizes can be explicitly registered for autotuning
+    with `hl.register_block_size()`.
+
     Examples:
 
         for tile in hl.tile(1000):
```
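To make the new calling conventions concrete, here is a hedged usage sketch; the kernel name, body, and shapes are illustrative rather than from this commit, and it assumes the usual `helion.kernel` entry point:

```python
import torch
import helion
import helion.language as hl


@helion.kernel()
def double_second_half(x: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(x)
    # New two-argument form: tiles cover [begin, end), block size autotuned.
    for tile in hl.tile(x.size(0) // 2, x.size(0)):
        out[tile] = x[tile] * 2
    return out
```

Per the docstring, `hl.tile(begin, end, block_size)` and `hl.tile(end, block_size=block_size)` pin the block size instead of autotuning it.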
```diff
@@ -81,51 +99,116 @@ def tile(
         for tile0, tile1 in hl.tile([1000, 1000]):
             ...
 
-    :param sizes: An integer or a sequence of integers representing the sizes for tiling.
+    :param begin_or_end: If 2 or more positional arguments are provided, the start of the iteration space. Otherwise, the end of the iteration space.
+    :param end_or_none: If 2 or more positional arguments are provided, the end of the iteration space.
     :return: A TileIndexProtocol object if a single size is provided, or a sequence of TileIndexProtocol objects if a sequence of sizes is provided.
     """
     raise exc.NotInsideKernel
 
 
+def _not_none(value: TypeInfo | None) -> TypeGuard[TypeInfo]:
+    return not (value is None or value.is_literal() and value.as_literal() is None)
+
+
+def _to_proxy(value: TypeInfo) -> object:
+    try:
+        return value.proxy()
+    except NotImplementedError:
+        raise exc.IncorrectTileUsage(
+            f"expected IntLike or list[IntLike], got {value!s}"
+        ) from None
+
+
+def _check_matching(a: object, b: object) -> None:
+    """Check that the types of `a` and `b` match for use in hl.tile."""
+    if isinstance(a, (list, tuple)):
+        if not isinstance(b, (list, tuple)):
+            raise exc.IncorrectTileUsage(
+                f"expected type hl.tile args to match, got {type(a)} and {type(b)}"
+            )
+        if len(a) != len(b):
+            raise exc.IncorrectTileUsage(
+                f"expected dims for hl.tile args to match, got {len(a)} and {len(b)}"
+            )
+    elif isinstance(a, (int, torch.SymInt, torch.Tensor)):
+        if not isinstance(b, (int, torch.SymInt, torch.Tensor)):
+            raise exc.IncorrectTileUsage(
+                f"expected type hl.tile args to match, got {type(a)} and {type(b)}"
+            )
+    else:
+        raise exc.IncorrectTileUsage(
+            f"expected type hl.tile args to be IntLike or list[IntLike], got {type(a)}"
+        )
+
+
+def _normalize_begin_end(
+    begin_or_end: TypeInfo,
+    end_or_none: TypeInfo | None,
+    origin: Origin,
+) -> tuple[TypeInfo, TypeInfo]:
+    """Fill in defaults for begin if it is not provided."""
+    if _not_none(end_or_none):
+        begin = begin_or_end
+        end = end_or_none
+    else:
+        try:
+            begin = TypeInfo.from_example(begin_or_end.tree_map(lambda n: 0), origin)
+        except NotImplementedError:
+            raise exc.TypePropagationError(
+                UnknownType(
+                    origin,
+                    f"expected IntLike or list[IntLike], got {begin_or_end!s}",
+                    chained_from=begin_or_end,
+                )
+            ) from None
+        end = begin_or_end
+    return begin, end
+
+
 @_decorators.type_propagation(tile)
 def _(
-    sizes: TypeInfo, block_size: TypeInfo | None = None, *, origin: Origin
+    begin_or_end: TypeInfo,
+    end_or_none: TypeInfo | None = None,
+    /,
+    block_size: TypeInfo | None = None,
+    *,
+    origin: Origin,
 ) -> TypeInfo:
     parent = ExtendedAST.current()[-2]
     if not isinstance(parent, ast.For):
         raise exc.LoopFunctionNotInFor("tile")
+    begin, end = _normalize_begin_end(begin_or_end, end_or_none, origin=origin)
+    proxy_begin = _to_proxy(begin)
+    proxy_end = _to_proxy(end)
+    _check_matching(proxy_begin, proxy_end)
+    if _not_none(block_size):
+        proxy_block_size = TileIndexProxy.tiles_to_sizes(_to_proxy(block_size))
+        _check_matching(proxy_end, proxy_block_size)
+    else:
+        proxy_block_size = begin.tree_map(lambda n: None)
+
+    if unpack := not isinstance(proxy_end, (list, tuple)):
+        proxy_begin = [proxy_begin]
+        proxy_end = [proxy_end]
+        proxy_block_size = [proxy_block_size]
+
     if (
-        block_size is None
-        or block_size.is_literal()
-        and block_size.as_literal() is None
+        all(bs is None for bs in proxy_block_size)
+        and all(isinstance(s, (int, torch.SymInt)) for s in proxy_begin)
+        and all(isinstance(s, (int, torch.SymInt)) for s in proxy_end)
     ):
-        result = _register_block_size_types(sizes, origin)
+        proxy_size = [e - b for b, e in zip(proxy_begin, proxy_end, strict=True)]
+        results = TileIndexType.allocate(proxy_size, origin)
     else:
-        try:
-            proxy_sizes = sizes.proxy()
-            proxy_block_size = TileIndexProxy.tiles_to_sizes(block_size.proxy())
-        except NotImplementedError:
-            raise exc.IncorrectTileUsage(
-                f"expected int or list[int], got {sizes!s} and {block_size!s}"
-            ) from None
-        if isinstance(proxy_sizes, (list, tuple)):
-            if not isinstance(proxy_block_size, (list, tuple)) or len(
-                proxy_sizes
-            ) != len(proxy_block_size):
-                raise exc.IncorrectTileUsage(
-                    f"expected dims for sizes and block_sizes to match, got {sizes!s} and {block_size!s}"
-                )
-            unpack = False
-        else:
-            if not isinstance(proxy_block_size, int | torch.SymInt):
-                raise exc.IncorrectTileUsage(
-                    f"expected type for sizes and block_sizes to match, got {sizes!s} and {block_size!s}"
-                )
-            proxy_sizes = [proxy_sizes]
-            proxy_block_size = [proxy_block_size]
-            unpack = True
+        # we must allocate the block sizes individually due to data dependent size or pre-allocated block sizes
+        # TODO(jansel): this flattens the structure of the config, which we should avoid
         results = []
-        for size, bs in zip(proxy_sizes, proxy_block_size, strict=True):
+        for begin_part, end_part, bs in zip(
+            proxy_begin, proxy_end, proxy_block_size, strict=True
+        ):
+            size = end_part - begin_part
+            if isinstance(size, torch.Tensor):
+                size = None  # data dependent size
             if bs is None:
                 results.append(TileIndexType.allocate([size], origin)[0])
             elif isinstance(bs, int):
```
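The begin/end dispatch follows `range()`: with one positional argument, `_normalize_begin_end` synthesizes a zero begin for every dimension via `tree_map`. A standalone sketch of that defaulting, with plain Python values in place of `TypeInfo` (the function name here is hypothetical):

```python
def normalize_begin_end(begin_or_end, end_or_none=None):
    """range()-style defaulting, mirroring _normalize_begin_end above."""
    if end_or_none is not None:
        return begin_or_end, end_or_none
    # One argument: treat it as end, with a zero begin per dimension.
    if isinstance(begin_or_end, (list, tuple)):
        return [0] * len(begin_or_end), begin_or_end
    return 0, begin_or_end


assert normalize_begin_end(1000) == (0, 1000)
assert normalize_begin_end(128, 1000) == (128, 1000)
assert normalize_begin_end([512, 1024]) == ([0, 0], [512, 1024])
```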
```diff
@@ -138,59 +221,14 @@ def _(
                 results.append(TileIndexType.allocate_fixed(size, bs, origin))
             else:
                 results.append(TileIndexType(origin=origin, block_size_idx=index))
-        if unpack:
-            (result,) = results
-        else:
-            result = SequenceType(origin, results)
-    return IterType(origin, result)
-
-
-def _register_block_size_types(sizes: TypeInfo, origin: Origin) -> TypeInfo:
-    if isinstance(sizes, SequenceType):
-        unpacked = sizes.unpack()
+                CompileEnvironment.current().block_sizes[index].mark_alternate_size(
+                    size
+                )
+    if unpack:
+        (result,) = results
     else:
-        unpacked = [sizes]
-    has_data_dependency = False
-    for size in unpacked:
-        if isinstance(size, TensorType) and size.origin.is_device():
-            has_data_dependency = True
-        elif isinstance(size, (LiteralType, SymIntType)) and isinstance(
-            size.proxy(), (int, torch.SymInt)
-        ):
-            pass
-        else:
-            raise exc.TypePropagationError(
-                UnknownType(
-                    origin,
-                    f"tile() expected int or list[int], got {size!s}",
-                    chained_from=size,
-                )
-            )
-    if has_data_dependency:
-        # TODO(jansel): support flatten/reorder for data dependencies
-        inner_types: list[TypeInfo] = []
-        for size in unpacked:
-            if isinstance(size, TensorType) and size.origin.is_device():
-                proxy = None
-            else:
-                proxy = size.proxy()
-                assert isinstance(proxy, (int, torch.SymInt))
-            inner_types.append(TileIndexType.allocate([proxy], origin)[0])
-        if isinstance(sizes, SequenceType):
-            return SequenceType(
-                origin=origin,
-                element_types=inner_types,
-            )
-        assert len(inner_types) == 1
-        return inner_types[0]
-    proxy_sizes = sizes.proxy()
-    if isinstance(proxy_sizes, (int, torch.SymInt)):
-        return TileIndexType.allocate([proxy_sizes], origin)[0]
-    return SequenceType(
-        origin=origin,
-        # pyre-fixme[6]
-        element_types=TileIndexType.allocate(proxy_sizes, origin),
-    )
+        result = SequenceType(origin, results)
+    return IterType(origin, result)
 
 
 def _get_block_indices(type_info: TypeInfo) -> list[int]:
```
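The `mark_alternate_size` call lands in the branch that handles a pre-registered block size, i.e. one created by `hl.register_block_size` and shared across loops. A hedged sketch of a kernel that would hit this path; the kernel name, body, and shapes are illustrative, and it assumes the extents of `a` and `b` may differ:

```python
import torch
import helion
import helion.language as hl


@helion.kernel()
def shared_block_size(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(a)
    bs = hl.register_block_size(a.size(0))  # one tunable block size...
    for tile in hl.tile(a.size(0), block_size=bs):
        out[tile] = a[tile] * 2
    # ...reused over a potentially different extent: mark_alternate_size
    # clears the size hint, so masked loads/stores get generated.
    for tile in hl.tile(b.size(0), block_size=bs):
        out[tile] = out[tile] + b[tile]
    return out
```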
```diff
@@ -334,6 +372,17 @@ def register_block_size(size: int | Sequence[int]) -> TileOutput | Sequence[Tile
     raise exc.NotInsideKernel
 
 
+def _register_block_size_types(sizes: TypeInfo, origin: Origin) -> TypeInfo:
+    proxy_sizes = sizes.proxy()
+    if isinstance(proxy_sizes, (int, torch.SymInt)):
+        return TileIndexType.allocate([proxy_sizes], origin)[0]
+    return SequenceType(
+        origin=origin,
+        # pyre-fixme[6]
+        element_types=TileIndexType.allocate(proxy_sizes, origin),
+    )
+
+
 @_decorators.type_propagation(register_block_size)
 def _(sizes: TypeInfo, *, origin: Origin) -> TypeInfo:
     return _register_block_size_types(sizes, origin)
```
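After the move, `_register_block_size_types` keeps only the simple path: one block-size slot per dimension, wrapped in a sequence only when a sequence was given (the data-dependency handling now lives inside `tile()` itself). A standalone sketch of that shape behavior, with strings standing in for `TileIndexType` and a hypothetical `allocate` stub:

```python
import itertools

_counter = itertools.count()


def allocate(sizes):
    """Stand-in for TileIndexType.allocate: one new block-size slot per dim."""
    return [f"block_size_{next(_counter)}" for _ in sizes]


def register_block_size_types(proxy_sizes):
    # int -> a single tile type; sequence -> a sequence of tile types
    if isinstance(proxy_sizes, int):
        return allocate([proxy_sizes])[0]
    return allocate(proxy_sizes)


assert register_block_size_types(1024) == "block_size_0"
assert register_block_size_types([128, 256]) == ["block_size_1", "block_size_2"]
```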
