@@ -70,6 +70,11 @@ def _fragment(self, base: ConfigSpec) -> ConfigSpecFragment:
70
70
"""Return the fragment used for autotunging for this item."""
71
71
raise NotImplementedError
72
72
73
+ def _flat_config (
74
+ self , base : ConfigSpec , fn : Callable [[ConfigSpecFragment ], object ]
75
+ ) -> object :
76
+ return fn (self ._fragment (base ))
77
+
73
78
74
79
_BlockIdItemT = TypeVar ("_BlockIdItemT" , bound = _BlockIdItem )
75
80
@@ -153,7 +158,7 @@ def _flat_config(
153
158
self , base : ConfigSpec , fn : Callable [[ConfigSpecFragment ], object ]
154
159
) -> list [object ]:
155
160
"""Map a flattened version of the config using the given function."""
156
- return [fn ( spec ._fragment (base ) ) for spec in self ._data ]
161
+ return [spec ._flat_config (base , fn ) for spec in self ._data ]
157
162
158
163
def _normalize (
159
164
self , name : str , values : object , * , flatten : bool = False
@@ -219,9 +224,8 @@ class ConfigSpec:
219
224
flatten_loops : BlockIdSequence [FlattenLoopSpec ] = dataclasses .field (
220
225
default_factory = BlockIdSequence
221
226
)
222
- # TODO(jansel): convert this to a BlockIdSequence[ReductionLoopSpec]
223
- reduction_loop_specs : list [ReductionLoopSpec ] = dataclasses .field (
224
- default_factory = list
227
+ reduction_loops : BlockIdSequence [ReductionLoopSpec ] = dataclasses .field (
228
+ default_factory = BlockIdSequence
225
229
)
226
230
allow_use_yz_grid : bool | None = None
227
231
@@ -254,15 +258,12 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
254
258
("flatten_loops" , self .flatten_loops , True ),
255
259
("l2_groupings" , self .l2_groupings , True ),
256
260
("loop_orders" , self .loop_orders , False ),
261
+ ("reduction_loops" , self .reduction_loops , True ),
257
262
]:
258
263
config [name ] = mapping ._normalize (
259
264
name , config .get (name , ()), flatten = flatten
260
265
)
261
266
262
- config ["reduction_loops" ] = self .normalize_reduction_loops (
263
- config .get ("reduction_loops" , None )
264
- )
265
-
266
267
for name in ("loop_orders" , "l2_groupings" , "flatten_loops" , "reduction_loops" ):
267
268
if not config [name ]:
268
269
config .pop (name )
@@ -278,22 +279,6 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
278
279
if invalid_keys := ({* config } - VALID_KEYS ):
279
280
raise InvalidConfig (f"Invalid config keys { sorted (invalid_keys )!r} " )
280
281
281
- def normalize_reduction_loops (self , reduction_loops : object ) -> list [int | None ]:
282
- assert isinstance (reduction_loops , (list , tuple , type (None ), int ))
283
- loops = [spec for spec in self .reduction_loop_specs if spec .allow_loop ]
284
- if reduction_loops is None :
285
- reduction_loops = [None for _ in loops ]
286
- elif isinstance (reduction_loops , int ):
287
- reduction_loops = [reduction_loops ]
288
- if len (reduction_loops ) != len (loops ):
289
- raise InvalidConfig (
290
- f"Invalid number of reduction loops, expected { len (loops )} got { len (reduction_loops )} "
291
- )
292
- return [
293
- spec .normalize (value )
294
- for spec , value in zip (loops , reduction_loops , strict = True )
295
- ]
296
-
297
282
def default_config (self ) -> helion .Config :
298
283
return self .flat_config (lambda x : x .default ())
299
284
@@ -304,11 +289,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
304
289
"loop_orders" : self .loop_orders ._flat_config (self , fn ),
305
290
"flatten_loops" : self .flatten_loops ._flat_config (self , fn ),
306
291
"l2_groupings" : self .l2_groupings ._flat_config (self , fn ),
307
- "reduction_loops" : [
308
- spec .flat_reduction_loop (fn )
309
- for spec in self .reduction_loop_specs
310
- if spec .allow_loop
311
- ],
292
+ "reduction_loops" : self .reduction_loops ._flat_config (self , fn ),
312
293
"num_warps" : fn (NumWarpsFragment (1 , 32 , DEFAULT_NUM_WARPS )),
313
294
"num_stages" : fn (IntegerFragment (1 , 8 , DEFAULT_NUM_STAGES )),
314
295
"indexing" : fn (
@@ -354,7 +335,7 @@ def _fill_missing(self) -> list[int]:
354
335
355
336
356
337
class _PowerOfTwoBlockIdItem (_BlockIdItem ):
357
- def _normalize (self , name : str , value : object ) -> int :
338
+ def _normalize (self , name : str , value : object ) -> int | None :
358
339
try :
359
340
return assert_integer_power_of_two (value )
360
341
except InvalidConfig :
@@ -413,7 +394,7 @@ def update_hint(self, value: int) -> None:
413
394
def _fragment (self , base : ConfigSpec ) -> BlockSizeFragment :
414
395
total_ndim = len (base .block_sizes )
415
396
reduction_numel = _product (
416
- [next_power_of_2 (spec .size_hint ) for spec in base .reduction_loop_specs ]
397
+ [next_power_of_2 (spec .size_hint ) for spec in base .reduction_loops ]
417
398
)
418
399
if total_ndim <= 1 and reduction_numel <= 1 :
419
400
default = 1024
@@ -443,31 +424,36 @@ def _fill_missing(self) -> bool:
443
424
return False
444
425
445
426
446
- @dataclasses .dataclass
447
- class ReductionLoopSpec :
448
- size_hint : int
449
- allow_loop : bool
450
-
451
- def normalize (self , value : int | None ) -> int | None :
452
- if value is None :
453
- return None
454
- assert_integer_power_of_two (value )
455
- if value < 0 or value >= next_power_of_2 (self .size_hint ):
456
- raise InvalidConfig (
457
- f"Invalid reduction loop value { value !r} , expected 0 to { next_power_of_2 (self .size_hint )} "
458
- )
459
- return value
427
+ class ReductionLoopSpec (_PowerOfTwoBlockIdItem ):
428
+ def __init__ (
429
+ self ,
430
+ * ,
431
+ block_id : int ,
432
+ size_hint : int ,
433
+ ) -> None :
434
+ super ().__init__ ([block_id ])
435
+ self .size_hint = size_hint
460
436
461
- def flat_reduction_loop (self , fn : Callable [[ConfigSpecFragment ], object ]) -> object :
462
- assert self .allow_loop
437
+ def _flat_config (
438
+ self , base : ConfigSpec , fn : Callable [[ConfigSpecFragment ], object ]
439
+ ) -> int | None :
463
440
low = 8 # TODO(jansel): is smaller needed?
464
441
high = next_power_of_2 (self .size_hint )
465
442
default = min (high , 4096 )
466
443
value = fn (BlockSizeFragment (low , high , default ))
467
- if value == high :
444
+ assert isinstance (value , int )
445
+ if value >= self .size_hint :
468
446
return None # max size becomes persistent reduction
469
447
return value
470
448
449
+ def _normalize (self , name : str , value : object ) -> int | None :
450
+ if value is None :
451
+ return None
452
+ return super ()._normalize (name , value )
453
+
454
+ def _fill_missing (self ) -> None :
455
+ return None
456
+
471
457
472
458
def _product (seq : Sequence [int ]) -> int :
473
459
"""Return the product of the elements in the sequence."""
0 commit comments