4
4
from typing import TYPE_CHECKING
5
5
from typing import Iterator
6
6
from typing import Sequence
7
+ from typing import TypeGuard
7
8
from typing import overload
8
9
9
10
import torch
12
13
from .._compiler .ast_extension import ExtendedAST
13
14
from .._compiler .ast_extension import LoopType
14
15
from .._compiler .ast_extension import expr_from_string
16
+ from .._compiler .compile_environment import CompileEnvironment
15
17
from .._compiler .tile_index_proxy import TileIndexProxy
16
18
from .._compiler .type_propagation import GridIndexType
17
19
from .._compiler .type_propagation import IterType
18
- from .._compiler .type_propagation import LiteralType
19
20
from .._compiler .type_propagation import Origin
20
21
from .._compiler .type_propagation import SequenceType
21
- from .._compiler .type_propagation import SymIntType
22
- from .._compiler .type_propagation import TensorType
23
22
from .._compiler .type_propagation import TileIndexType
24
23
from .._compiler .type_propagation import TypeInfo
25
24
from .._compiler .type_propagation import UnknownType
40
39
@_decorators .api (
41
40
is_device_loop = True , is_device_only = False , cache_type = True , tiles_as_sizes = True
42
41
)
43
- def tile (sizes : int , / , block_size : object = None ) -> Iterator [TileOutput ]: ...
42
+ def tile (
43
+ begin_or_end : int ,
44
+ end_or_none : int | None = None ,
45
+ / ,
46
+ block_size : object = None ,
47
+ ) -> Iterator [TileOutput ]: ...
44
48
45
49
46
50
@overload
47
51
@_decorators .api (
48
52
is_device_loop = True , is_device_only = False , cache_type = True , tiles_as_sizes = True
49
53
)
50
54
def tile (
51
- sizes : Sequence [int ], / , block_size : object = None
55
+ begin_or_end : Sequence [int ],
56
+ end_or_none : Sequence [int ] | None = None ,
57
+ / ,
58
+ block_size : object = None ,
52
59
) -> Iterator [Sequence [TileOutput ]]: ...
53
60
54
61
55
62
@_decorators .api (
56
63
is_device_loop = True , is_device_only = False , cache_type = True , tiles_as_sizes = True
57
64
)
58
65
def tile (
59
- sizes : int | Sequence [int ],
66
+ begin_or_end : int | Sequence [int ],
67
+ end_or_none : int | Sequence [int ] | None = None ,
60
68
/ ,
61
69
block_size : object = None ,
62
70
) -> Iterator [TileOutput ] | Iterator [Sequence [TileOutput ]]:
@@ -73,6 +81,16 @@ def tile(
73
81
If used at the top level of a function, this becomes the grid of the kernel.
74
82
Otherwise, it becomes a loop in the output kernel.
75
83
84
+ Similar to `range()` there are multiple forms of this function:
85
+ tile(end) iterates from 0 to `end - 1`, with autotuned block_size.
86
+ tile(begin, end) iterates from `begin` to `end - 1`, with autotuned block_size.
87
+ tile(begin, end, block_size) iterates from `begin` to `end - 1`, with the given block_size.
88
+ tile(end, block_size=block_size) iterates from 0 to `end - 1`, with the given block_size.
89
+
90
+ begin/end/block_size can be a single integer or a sequence of integers to specify
91
+ multidimensional iteration. Block sizes can be explicitly registered for autotuning
92
+ with `hl.register_block_size()`.
93
+
76
94
Examples:
77
95
78
96
for tile in hl.tile(1000):
@@ -81,51 +99,116 @@ def tile(
81
99
for tile0, tile1 in hl.tile([1000, 1000]):
82
100
...
83
101
84
- :param sizes: An integer or a sequence of integers representing the sizes for tiling.
102
+ :param begin_or_end: If 2 or more positional arguments are provided, the start of the iteration space. Otherwise, the end of the iteration space.
103
+ :param end_or_none: If 2 or more positional arguments are provided, the end of the iteration space.
85
104
:return: A TileIndexProtocol object if a single size is provided, or a sequence of TileIndexProtocol objects if a sequence of sizes is provided.
86
105
"""
87
106
raise exc .NotInsideKernel
88
107
89
108
109
+ def _not_none (value : TypeInfo | None ) -> TypeGuard [TypeInfo ]:
110
+ return not (value is None or value .is_literal () and value .as_literal () is None )
111
+
112
+
113
+ def _to_proxy (value : TypeInfo ) -> object :
114
+ try :
115
+ return value .proxy ()
116
+ except NotImplementedError :
117
+ raise exc .IncorrectTileUsage (
118
+ f"expected IntLike or list[IntLike], got { value !s} "
119
+ ) from None
120
+
121
+
122
+ def _check_matching (a : object , b : object ) -> None :
123
+ """Check that the types of `a` and `b` match for use in hl.tile."""
124
+ if isinstance (a , (list , tuple )):
125
+ if not isinstance (b , (list , tuple )):
126
+ raise exc .IncorrectTileUsage (
127
+ f"expected type hl.tile args to match, got { type (a )} and { type (b )} "
128
+ )
129
+ if len (a ) != len (b ):
130
+ raise exc .IncorrectTileUsage (
131
+ f"expected dims for hl.tile args to match, got { len (a )} and { len (b )} "
132
+ )
133
+ elif isinstance (a , (int , torch .SymInt , torch .Tensor )):
134
+ if not isinstance (b , (int , torch .SymInt , torch .Tensor )):
135
+ raise exc .IncorrectTileUsage (
136
+ f"expected type hl.tile args to match, got { type (a )} and { type (b )} "
137
+ )
138
+ else :
139
+ raise exc .IncorrectTileUsage (
140
+ f"expected type hl.tile args to be IntLike or list[IntLike], got { type (a )} "
141
+ )
142
+
143
+
144
+ def _normalize_begin_end (
145
+ begin_or_end : TypeInfo ,
146
+ end_or_none : TypeInfo | None ,
147
+ origin : Origin ,
148
+ ) -> tuple [TypeInfo , TypeInfo ]:
149
+ """Fill in defaults for begin if it is not provided."""
150
+ if _not_none (end_or_none ):
151
+ begin = begin_or_end
152
+ end = end_or_none
153
+ else :
154
+ try :
155
+ begin = TypeInfo .from_example (begin_or_end .tree_map (lambda n : 0 ), origin )
156
+ except NotImplementedError :
157
+ raise exc .TypePropagationError (
158
+ UnknownType (
159
+ origin ,
160
+ f"expected IntLike or list[IntLike], got { begin_or_end !s} " ,
161
+ chained_from = begin_or_end ,
162
+ )
163
+ ) from None
164
+ end = begin_or_end
165
+ return begin , end
166
+
167
+
90
168
@_decorators .type_propagation (tile )
91
169
def _ (
92
- sizes : TypeInfo , block_size : TypeInfo | None = None , * , origin : Origin
170
+ begin_or_end : TypeInfo ,
171
+ end_or_none : TypeInfo | None = None ,
172
+ / ,
173
+ block_size : TypeInfo | None = None ,
174
+ * ,
175
+ origin : Origin ,
93
176
) -> TypeInfo :
94
177
parent = ExtendedAST .current ()[- 2 ]
95
178
if not isinstance (parent , ast .For ):
96
179
raise exc .LoopFunctionNotInFor ("tile" )
180
+ begin , end = _normalize_begin_end (begin_or_end , end_or_none , origin = origin )
181
+ proxy_begin = _to_proxy (begin )
182
+ proxy_end = _to_proxy (end )
183
+ _check_matching (proxy_begin , proxy_end )
184
+ if _not_none (block_size ):
185
+ proxy_block_size = TileIndexProxy .tiles_to_sizes (_to_proxy (block_size ))
186
+ _check_matching (proxy_end , proxy_block_size )
187
+ else :
188
+ proxy_block_size = begin .tree_map (lambda n : None )
189
+
190
+ if unpack := not isinstance (proxy_end , (list , tuple )):
191
+ proxy_begin = [proxy_begin ]
192
+ proxy_end = [proxy_end ]
193
+ proxy_block_size = [proxy_block_size ]
194
+
97
195
if (
98
- block_size is None
99
- or block_size . is_literal ( )
100
- and block_size . as_literal () is None
196
+ all ( bs is None for bs in proxy_block_size )
197
+ and all ( isinstance ( s , ( int , torch . SymInt )) for s in proxy_begin )
198
+ and all ( isinstance ( s , ( int , torch . SymInt )) for s in proxy_end )
101
199
):
102
- result = _register_block_size_types (sizes , origin )
200
+ proxy_size = [e - b for b , e in zip (proxy_begin , proxy_end , strict = True )]
201
+ results = TileIndexType .allocate (proxy_size , origin )
103
202
else :
104
- try :
105
- proxy_sizes = sizes .proxy ()
106
- proxy_block_size = TileIndexProxy .tiles_to_sizes (block_size .proxy ())
107
- except NotImplementedError :
108
- raise exc .IncorrectTileUsage (
109
- f"expected int or list[int], got { sizes !s} and { block_size !s} "
110
- ) from None
111
- if isinstance (proxy_sizes , (list , tuple )):
112
- if not isinstance (proxy_block_size , (list , tuple )) or len (
113
- proxy_sizes
114
- ) != len (proxy_block_size ):
115
- raise exc .IncorrectTileUsage (
116
- f"expected dims for sizes and block_sizes to match, got { sizes !s} and { block_size !s} "
117
- )
118
- unpack = False
119
- else :
120
- if not isinstance (proxy_block_size , int | torch .SymInt ):
121
- raise exc .IncorrectTileUsage (
122
- f"expected type for sizes and block_sizes to match, got { sizes !s} and { block_size !s} "
123
- )
124
- proxy_sizes = [proxy_sizes ]
125
- proxy_block_size = [proxy_block_size ]
126
- unpack = True
203
+ # we must allocate the block sizes individually due to data dependent size or pre-allocated block sizes
204
+ # TODO(jansel): this flattens the structure of the config, which we should avoid
127
205
results = []
128
- for size , bs in zip (proxy_sizes , proxy_block_size , strict = True ):
206
+ for begin_part , end_part , bs in zip (
207
+ proxy_begin , proxy_end , proxy_block_size , strict = True
208
+ ):
209
+ size = end_part - begin_part
210
+ if isinstance (size , torch .Tensor ):
211
+ size = None # data dependent size
129
212
if bs is None :
130
213
results .append (TileIndexType .allocate ([size ], origin )[0 ])
131
214
elif isinstance (bs , int ):
@@ -138,59 +221,14 @@ def _(
138
221
results .append (TileIndexType .allocate_fixed (size , bs , origin ))
139
222
else :
140
223
results .append (TileIndexType (origin = origin , block_size_idx = index ))
141
- if unpack :
142
- (result ,) = results
143
- else :
144
- result = SequenceType (origin , results )
145
- return IterType (origin , result )
146
-
147
-
148
- def _register_block_size_types (sizes : TypeInfo , origin : Origin ) -> TypeInfo :
149
- if isinstance (sizes , SequenceType ):
150
- unpacked = sizes .unpack ()
224
+ CompileEnvironment .current ().block_sizes [index ].mark_alternate_size (
225
+ size
226
+ )
227
+ if unpack :
228
+ (result ,) = results
151
229
else :
152
- unpacked = [sizes ]
153
- has_data_dependency = False
154
- for size in unpacked :
155
- if isinstance (size , TensorType ) and size .origin .is_device ():
156
- has_data_dependency = True
157
- elif isinstance (size , (LiteralType , SymIntType )) and isinstance (
158
- size .proxy (), (int , torch .SymInt )
159
- ):
160
- pass
161
- else :
162
- raise exc .TypePropagationError (
163
- UnknownType (
164
- origin ,
165
- f"tile() expected int or list[int], got { size !s} " ,
166
- chained_from = size ,
167
- )
168
- )
169
- if has_data_dependency :
170
- # TODO(jansel): support flatten/reorder for data dependencies
171
- inner_types : list [TypeInfo ] = []
172
- for size in unpacked :
173
- if isinstance (size , TensorType ) and size .origin .is_device ():
174
- proxy = None
175
- else :
176
- proxy = size .proxy ()
177
- assert isinstance (proxy , (int , torch .SymInt ))
178
- inner_types .append (TileIndexType .allocate ([proxy ], origin )[0 ])
179
- if isinstance (sizes , SequenceType ):
180
- return SequenceType (
181
- origin = origin ,
182
- element_types = inner_types ,
183
- )
184
- assert len (inner_types ) == 1
185
- return inner_types [0 ]
186
- proxy_sizes = sizes .proxy ()
187
- if isinstance (proxy_sizes , (int , torch .SymInt )):
188
- return TileIndexType .allocate ([proxy_sizes ], origin )[0 ]
189
- return SequenceType (
190
- origin = origin ,
191
- # pyre-fixme[6]
192
- element_types = TileIndexType .allocate (proxy_sizes , origin ),
193
- )
230
+ result = SequenceType (origin , results )
231
+ return IterType (origin , result )
194
232
195
233
196
234
def _get_block_indices (type_info : TypeInfo ) -> list [int ]:
@@ -334,6 +372,17 @@ def register_block_size(size: int | Sequence[int]) -> TileOutput | Sequence[Tile
334
372
raise exc .NotInsideKernel
335
373
336
374
375
+ def _register_block_size_types (sizes : TypeInfo , origin : Origin ) -> TypeInfo :
376
+ proxy_sizes = sizes .proxy ()
377
+ if isinstance (proxy_sizes , (int , torch .SymInt )):
378
+ return TileIndexType .allocate ([proxy_sizes ], origin )[0 ]
379
+ return SequenceType (
380
+ origin = origin ,
381
+ # pyre-fixme[6]
382
+ element_types = TileIndexType .allocate (proxy_sizes , origin ),
383
+ )
384
+
385
+
337
386
@_decorators .type_propagation (register_block_size )
338
387
def _ (sizes : TypeInfo , * , origin : Origin ) -> TypeInfo :
339
388
return _register_block_size_types (sizes , origin )
0 commit comments