Commit b49bd50

Refactor reduction loop config spec
stack-info: PR: #128, branch: jansel/stack/24
1 parent fa08371

4 files changed: +51 -67 lines changed

helion/_compiler/device_ir.py

Lines changed: 7 additions & 6 deletions
@@ -331,13 +331,14 @@ def build_rolled_reductions(self) -> None:
                 allow_loop = allow_loop or reduction_info.used_rdim
                 self.rolled_reductions.append(reduction_info)
                 graph_to_info[graph_id] = reduction_info
-            env.config_spec.reduction_loop_specs.append(
-                ReductionLoopSpec(
-                    size_hint=rdim.size_hint(),
-                    # TODO(jansel): we should add support for rolling multiple dims at once
-                    allow_loop=allow_loop and first,
+            if allow_loop and first:
+                # TODO(jansel): we should add support for rolling multiple dims at once
+                env.config_spec.reduction_loops.append(
+                    ReductionLoopSpec(
+                        block_id=rdim.block_size_idx,
+                        size_hint=rdim.size_hint(),
+                    )
                 )
-            )
             first = False
 
     def __enter__(self) -> None:
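
This change inverts the old pattern: instead of appending a ReductionLoopSpec for every reduction dim and carrying an allow_loop flag, the builder now registers a spec only for the first loop-eligible dim, keyed by its block_id. A minimal sketch of that registration rule, using simplified stand-in dataclasses rather than the real Helion types:

from __future__ import annotations

import dataclasses


@dataclasses.dataclass
class ReductionLoopSpec:  # stand-in, not the real Helion class
    block_id: int
    size_hint: int


@dataclasses.dataclass
class ConfigSpec:  # stand-in, not the real Helion class
    reduction_loops: list[ReductionLoopSpec] = dataclasses.field(default_factory=list)


def register_reduction_dims(spec: ConfigSpec, rdims: list[tuple[int, int, bool]]) -> None:
    """rdims holds (block_id, size_hint, allow_loop) per reduction dim, in order."""
    first = True
    for block_id, size_hint, allow_loop in rdims:
        if allow_loop and first:
            # Only the first loop-eligible dim is registered; rolling several
            # dims at once remains a TODO in the upstream code.
            spec.reduction_loops.append(
                ReductionLoopSpec(block_id=block_id, size_hint=size_hint)
            )
        first = False


spec = ConfigSpec()
register_reduction_dims(spec, [(3, 512, True), (4, 64, True)])
assert [s.block_id for s in spec.reduction_loops] == [3]

With this shape, membership in reduction_loops itself expresses loop eligibility, which is exactly what the tile_dispatch and test changes below rely on.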

helion/_compiler/tile_dispatch.py

Lines changed: 7 additions & 10 deletions
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import collections
 import functools
 import operator
 from typing import TYPE_CHECKING
@@ -99,18 +98,16 @@ def _add_loop_strategy(
     def _add_reduction_strategies(self, fn: DeviceFunction, config: Config) -> None:
         env = CompileEnvironment.current()
         rdims = [bs.block_size_idx for bs in env.block_sizes if bs.reduction]
-        reduction_loops = collections.deque(config.reduction_loops)
-        for rdim_index, rdim_spec in zip(
-            rdims, env.config_spec.reduction_loop_specs, strict=True
-        ):
-            reduction_loop = reduction_loops.popleft() if rdim_spec.allow_loop else None
+        for block_id in rdims:
+            reduction_loop = env.config_spec.reduction_loops.config_get(
+                config.reduction_loops, block_id, None
+            )
             if reduction_loop is None:
-                strategy: TileStrategy = PersistentReductionStrategy(fn, rdim_index)
+                strategy: TileStrategy = PersistentReductionStrategy(fn, block_id)
             else:
-                strategy = LoopedReductionStrategy(fn, rdim_index, reduction_loop)
+                strategy = LoopedReductionStrategy(fn, block_id, reduction_loop)
             self.strategies.append(strategy)
-            self.block_indices_to_strategy[(rdim_index,)] = strategy
-        assert not reduction_loops
+            self.block_indices_to_strategy[(block_id,)] = strategy
 
     def codegen_grid(self, state: CodegenState, block_indices: list[int]) -> None:
         strategy = self.block_indices_to_strategy[tuple(block_indices)]
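
On the consumer side, the positional protocol (a deque popped once per allow_loop spec, with a trailing assert that it was drained) becomes a keyed lookup. The sketch below is one plausible reading of the config_get call site: the config's value list runs parallel to the registered specs, and a dim with no registered spec falls back to the default of None, which the dispatcher turns into a PersistentReductionStrategy. The helper name and call shape come from the diff; the body is an assumption, not Helion's actual implementation.

from __future__ import annotations

from typing import Sequence, TypeVar

T = TypeVar("T")


class _Spec:
    """Stand-in for a _BlockIdItem: knows which block ids it covers."""

    def __init__(self, block_ids: list[int]) -> None:
        self.block_ids = block_ids


class BlockIdSequence:  # hypothetical simplified version
    def __init__(self, specs: list[_Spec]) -> None:
        self._data = specs

    def config_get(
        self, values: Sequence[T], block_id: int, default: T | None
    ) -> T | None:
        # Config values are parallel to the registered specs, so the value
        # for a block_id sits at the index of the spec that covers it.
        for index, spec in enumerate(self._data):
            if block_id in spec.block_ids and index < len(values):
                return values[index]
        return default


seq = BlockIdSequence([_Spec([3])])
assert seq.config_get([64], 3, None) == 64  # looped: reduction loop of 64
assert seq.config_get([64], 5, None) is None  # unregistered dim: persistent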

helion/autotuner/config_spec.py

Lines changed: 34 additions & 48 deletions
@@ -70,6 +70,11 @@ def _fragment(self, base: ConfigSpec) -> ConfigSpecFragment:
         """Return the fragment used for autotuning for this item."""
         raise NotImplementedError
 
+    def _flat_config(
+        self, base: ConfigSpec, fn: Callable[[ConfigSpecFragment], object]
+    ) -> object:
+        return fn(self._fragment(base))
+
 
 _BlockIdItemT = TypeVar("_BlockIdItemT", bound=_BlockIdItem)
 
@@ -153,7 +158,7 @@ def _flat_config(
         self, base: ConfigSpec, fn: Callable[[ConfigSpecFragment], object]
     ) -> list[object]:
         """Map a flattened version of the config using the given function."""
-        return [fn(spec._fragment(base)) for spec in self._data]
+        return [spec._flat_config(base, fn) for spec in self._data]
 
     def _normalize(
         self, name: str, values: object, *, flatten: bool = False
@@ -219,9 +224,8 @@ class ConfigSpec:
     flatten_loops: BlockIdSequence[FlattenLoopSpec] = dataclasses.field(
         default_factory=BlockIdSequence
    )
-    # TODO(jansel): convert this to a BlockIdSequence[ReductionLoopSpec]
-    reduction_loop_specs: list[ReductionLoopSpec] = dataclasses.field(
-        default_factory=list
+    reduction_loops: BlockIdSequence[ReductionLoopSpec] = dataclasses.field(
+        default_factory=BlockIdSequence
     )
     allow_use_yz_grid: bool | None = None
 
@@ -254,15 +258,12 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
             ("flatten_loops", self.flatten_loops, True),
             ("l2_groupings", self.l2_groupings, True),
             ("loop_orders", self.loop_orders, False),
+            ("reduction_loops", self.reduction_loops, True),
         ]:
             config[name] = mapping._normalize(
                 name, config.get(name, ()), flatten=flatten
             )
 
-        config["reduction_loops"] = self.normalize_reduction_loops(
-            config.get("reduction_loops", None)
-        )
-
         for name in ("loop_orders", "l2_groupings", "flatten_loops", "reduction_loops"):
             if not config[name]:
                 config.pop(name)
@@ -278,22 +279,6 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
         if invalid_keys := ({*config} - VALID_KEYS):
             raise InvalidConfig(f"Invalid config keys {sorted(invalid_keys)!r}")
 
-    def normalize_reduction_loops(self, reduction_loops: object) -> list[int | None]:
-        assert isinstance(reduction_loops, (list, tuple, type(None), int))
-        loops = [spec for spec in self.reduction_loop_specs if spec.allow_loop]
-        if reduction_loops is None:
-            reduction_loops = [None for _ in loops]
-        elif isinstance(reduction_loops, int):
-            reduction_loops = [reduction_loops]
-        if len(reduction_loops) != len(loops):
-            raise InvalidConfig(
-                f"Invalid number of reduction loops, expected {len(loops)} got {len(reduction_loops)}"
-            )
-        return [
-            spec.normalize(value)
-            for spec, value in zip(loops, reduction_loops, strict=True)
-        ]
-
     def default_config(self) -> helion.Config:
         return self.flat_config(lambda x: x.default())
 
@@ -304,11 +289,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Config:
             "loop_orders": self.loop_orders._flat_config(self, fn),
             "flatten_loops": self.flatten_loops._flat_config(self, fn),
             "l2_groupings": self.l2_groupings._flat_config(self, fn),
-            "reduction_loops": [
-                spec.flat_reduction_loop(fn)
-                for spec in self.reduction_loop_specs
-                if spec.allow_loop
-            ],
+            "reduction_loops": self.reduction_loops._flat_config(self, fn),
             "num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)),
             "num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
             "indexing": fn(
@@ -354,7 +335,7 @@ def _fill_missing(self) -> list[int]:
 
 
 class _PowerOfTwoBlockIdItem(_BlockIdItem):
-    def _normalize(self, name: str, value: object) -> int:
+    def _normalize(self, name: str, value: object) -> int | None:
         try:
             return assert_integer_power_of_two(value)
         except InvalidConfig:
@@ -413,7 +394,7 @@ def update_hint(self, value: int) -> None:
     def _fragment(self, base: ConfigSpec) -> BlockSizeFragment:
         total_ndim = len(base.block_sizes)
         reduction_numel = _product(
-            [next_power_of_2(spec.size_hint) for spec in base.reduction_loop_specs]
+            [next_power_of_2(spec.size_hint) for spec in base.reduction_loops]
         )
         if total_ndim <= 1 and reduction_numel <= 1:
             default = 1024
@@ -443,31 +424,36 @@ def _fill_missing(self) -> bool:
         return False
 
 
-@dataclasses.dataclass
-class ReductionLoopSpec:
-    size_hint: int
-    allow_loop: bool
-
-    def normalize(self, value: int | None) -> int | None:
-        if value is None:
-            return None
-        assert_integer_power_of_two(value)
-        if value < 0 or value >= next_power_of_2(self.size_hint):
-            raise InvalidConfig(
-                f"Invalid reduction loop value {value!r}, expected 0 to {next_power_of_2(self.size_hint)}"
-            )
-        return value
+class ReductionLoopSpec(_PowerOfTwoBlockIdItem):
+    def __init__(
+        self,
+        *,
+        block_id: int,
+        size_hint: int,
+    ) -> None:
+        super().__init__([block_id])
+        self.size_hint = size_hint
 
-    def flat_reduction_loop(self, fn: Callable[[ConfigSpecFragment], object]) -> object:
-        assert self.allow_loop
+    def _flat_config(
+        self, base: ConfigSpec, fn: Callable[[ConfigSpecFragment], object]
+    ) -> int | None:
         low = 8  # TODO(jansel): is smaller needed?
         high = next_power_of_2(self.size_hint)
         default = min(high, 4096)
         value = fn(BlockSizeFragment(low, high, default))
-        if value == high:
+        assert isinstance(value, int)
+        if value >= self.size_hint:
             return None  # max size becomes persistent reduction
         return value
 
+    def _normalize(self, name: str, value: object) -> int | None:
+        if value is None:
+            return None
+        return super()._normalize(name, value)
+
+    def _fill_missing(self) -> None:
+        return None
+
 
 def _product(seq: Sequence[int]) -> int:
     """Return the product of the elements in the sequence."""

test/test_specialize.py

Lines changed: 3 additions & 3 deletions
@@ -157,7 +157,7 @@ def fn(
         x = torch.randn([512, 512], device=DEVICE)
         code, result = code_and_output(fn, (x,), block_size=32)
         torch.testing.assert_close(result, x + 1)
-        self.assertFalse(fn.bind((x,)).config_spec.reduction_loop_specs[0].allow_loop)
+        self.assertEqual(len(fn.bind((x,)).config_spec.reduction_loops), 0)
         self.assertExpectedInline(
             code,
             """\
@@ -214,7 +214,7 @@ def fn(
         x = torch.randn([500, 500], device=DEVICE)
         code, result = code_and_output(fn, (x,), block_size=32)
         torch.testing.assert_close(result, x + 1)
-        self.assertFalse(fn.bind((x,)).config_spec.reduction_loop_specs[0].allow_loop)
+        self.assertEqual(len(fn.bind((x,)).config_spec.reduction_loops), 0)
         self.assertIs(
             fn.bind((x,)),
             fn.bind((torch.zeros_like(x),)),
@@ -278,7 +278,7 @@ def fn(
         x = torch.randn([500, 500], device=DEVICE)
         code, result = code_and_output(fn, (x,), block_size=32)
         torch.testing.assert_close(result, x.sum(-1))
-        self.assertTrue(fn.bind((x,)).config_spec.reduction_loop_specs[0].allow_loop)
+        self.assertEqual(len(fn.bind((x,)).config_spec.reduction_loops), 1)
         self.assertExpectedInline(
             code,
             """\
