Skip to content

Commit f6b30d1

Browse files
authored
Rename block_indices to block_ids (#135)
Some minor refactoring to make naming more consistent
1 parent e7ace97 commit f6b30d1

File tree

7 files changed

+97
-111
lines changed

7 files changed

+97
-111
lines changed

helion/_compiler/device_ir.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def kwargs(self) -> dict[str, object]:
188188

189189
@dataclasses.dataclass
190190
class ForLoopGraphInfo(NodeArgsGraphInfo):
191-
block_indices: list[int]
191+
block_ids: list[int]
192192

193193
@property
194194
def name(self) -> str:
@@ -197,7 +197,7 @@ def name(self) -> str:
197197
def kwargs(self) -> dict[str, object]:
198198
return {
199199
**super().kwargs(),
200-
"block_indices": [*self.block_indices],
200+
"block_ids": [*self.block_ids],
201201
}
202202

203203
def codegen(self, state: CodegenState) -> list[object]:
@@ -206,7 +206,7 @@ def codegen(self, state: CodegenState) -> list[object]:
206206
assert all(isinstance(x, ast.AST) for x in args)
207207
with state.codegen.add_device_loop(
208208
state.device_function.tile_strategy.codegen_device_loop(
209-
state, self.block_indices
209+
state, self.block_ids
210210
)
211211
):
212212
return codegen_call_with_graph(
@@ -238,7 +238,7 @@ def codegen(self, state: CodegenState) -> list[object]:
238238

239239

240240
class RolledReductionInfo(NamedTuple):
241-
rolled_block_indices: list[int]
241+
rolled_block_ids: list[int]
242242
original_graph_id: int
243243
new_graph_id: int | None
244244
used_rdim: bool
@@ -251,7 +251,7 @@ def __init__(self) -> None:
251251
self.graphs: list[GraphInfo] = []
252252
self.root_id: int | None = None
253253
self.rolled_reductions: list[RolledReductionInfo] = []
254-
self.grid_block_indices: list[list[int]] = []
254+
self.grid_block_ids: list[list[int]] = []
255255

256256
def get_root(self, config: Config) -> torch.fx.Graph:
257257
""" " If we are using a rolled reduction, return the rolled reduction graph otherwise
@@ -296,7 +296,7 @@ def add_reduction_loop_graph(
296296
return self.add_graph(
297297
graph,
298298
graph_info_cls=ReductionLoopGraphInfo,
299-
block_indices=[block_index],
299+
block_ids=[block_index],
300300
node_args=node_args,
301301
)
302302

@@ -321,7 +321,7 @@ def build_rolled_reductions(self) -> None:
321321
new_graph, type(graph_info), **graph_info.kwargs()
322322
)
323323
reduction_info = RolledReductionInfo(
324-
rolled_block_indices=[rdim.block_id],
324+
rolled_block_ids=[rdim.block_id],
325325
original_graph_id=graph_id,
326326
new_graph_id=new_graph_id,
327327
used_rdim=len(roller.graphs_added) > 0,
@@ -545,7 +545,7 @@ def run_subgraph(*args: object) -> list[object]:
545545
graph_idx = self.device_ir.add_graph(
546546
graph,
547547
ForLoopGraphInfo,
548-
block_indices=[x.block_id for x in iter_vars],
548+
block_ids=[x.block_id for x in iter_vars],
549549
node_args=inputs.get_node_args(tracer),
550550
)
551551
args = (
@@ -826,10 +826,10 @@ def visit_For(self, node: ast.For) -> None:
826826
assert isinstance(iter_type, IterType)
827827
inner = iter_type.inner
828828
if isinstance(inner, SequenceType):
829-
block_indices = [x.block_id for x in inner.unpack()]
829+
block_ids = [x.block_id for x in inner.unpack()]
830830
else:
831-
block_indices = [inner.block_id]
832-
self.device_ir.grid_block_indices.append(block_indices)
831+
block_ids = [inner.block_id]
832+
self.device_ir.grid_block_ids.append(block_ids)
833833
else:
834834
self.generic_visit(node)
835835

helion/_compiler/generate_ast.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,22 +97,22 @@ def set_on_device(self) -> Iterator[None]:
9797
@contextlib.contextmanager
9898
def add_device_loop(self, device_loop: DeviceLoopState) -> Iterator[None]:
9999
with self.set_statements(device_loop.inner_statements):
100-
for idx in device_loop.block_indices:
100+
for idx in device_loop.block_ids:
101101
active_loops = self.active_device_loops[idx]
102102
active_loops.append(device_loop)
103103
if len(active_loops) > 1:
104104
raise exc.NestedDeviceLoopsConflict
105105
try:
106106
yield
107107
finally:
108-
for idx in device_loop.block_indices:
108+
for idx in device_loop.block_ids:
109109
self.active_device_loops[idx].pop()
110110
self.statements_stack[-1].extend(device_loop.outer_prefix)
111111
self.add_statement(device_loop.for_node)
112112
self.statements_stack[-1].extend(device_loop.outer_suffix)
113113

114114
def set_active_loops(self, device_grid: DeviceLoopOrGridState) -> None:
115-
for idx in device_grid.block_indices:
115+
for idx in device_grid.block_ids:
116116
self.active_device_loops[idx] = [device_grid]
117117

118118
def generic_visit(self, node: ast.AST) -> ast.AST:

helion/_compiler/reduction_strategy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(
3939
) -> None:
4040
super().__init__(
4141
fn=fn,
42-
block_indices=[block_index],
42+
block_ids=[block_index],
4343
)
4444
self._mask_var = mask_var
4545
if block_size_var is not None:
@@ -51,7 +51,7 @@ def mask_var(self, block_idx: int) -> str | None:
5151

5252
@property
5353
def block_index(self) -> int:
54-
return self.block_indices[0]
54+
return self.block_ids[0]
5555

5656
def user_size(self, block_index: int) -> sympy.Expr:
5757
return CompileEnvironment.current().block_sizes[block_index].numel

helion/_compiler/tile_dispatch.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,37 +41,37 @@ def __init__(
4141
) -> None:
4242
super().__init__()
4343
self.strategies: list[TileStrategy] = []
44-
self.block_indices_to_strategy: dict[tuple[int, ...], TileStrategy] = {}
44+
self.block_id_to_strategy: dict[tuple[int, ...], TileStrategy] = {}
4545
self._add_loop_strategies(fn, config)
4646
self._add_reduction_strategies(fn, config)
4747

4848
def _add_loop_strategies(self, fn: DeviceFunction, config: Config) -> None:
4949
device_ir = HostFunction.current().device_ir
50-
for block_indices in device_ir.grid_block_indices:
51-
self._add_loop_strategy(block_indices, fn, config)
50+
for block_ids in device_ir.grid_block_ids:
51+
self._add_loop_strategy(block_ids, fn, config)
5252
for graph in device_ir.graphs:
5353
if isinstance(graph, ForLoopGraphInfo) and not isinstance(
5454
graph, ReductionLoopGraphInfo
5555
):
56-
block_indices = [*graph.block_indices]
57-
self._add_loop_strategy(block_indices, fn, config)
56+
block_ids = [*graph.block_ids]
57+
self._add_loop_strategy(block_ids, fn, config)
5858

5959
def _add_loop_strategy(
60-
self, block_indices: list[int], fn: DeviceFunction, config: Config
60+
self, block_ids: list[int], fn: DeviceFunction, config: Config
6161
) -> None:
6262
env = CompileEnvironment.current()
63-
block_size_infos = [env.block_sizes[i] for i in block_indices]
63+
block_size_infos = [env.block_sizes[i] for i in block_ids]
6464
loop_order = env.config_spec.loop_orders.config_get(
65-
config.loop_orders, block_indices[0]
66-
) or [*range(len(block_indices))]
65+
config.loop_orders, block_ids[0]
66+
) or [*range(len(block_ids))]
6767
l2_grouping = env.config_spec.l2_groupings.config_get(
68-
config.l2_groupings, block_indices[0], 1
68+
config.l2_groupings, block_ids[0], 1
6969
)
7070

7171
if block_size_infos[0].is_grid():
7272
strategy: TileStrategy = NDGridTileStrategy(
7373
fn,
74-
block_indices,
74+
block_ids,
7575
loop_order=loop_order,
7676
)
7777
elif block_size_infos[0].is_flattened(config):
@@ -80,20 +80,20 @@ def _add_loop_strategy(
8080
)
8181
strategy: TileStrategy = FlattenedTileStrategy(
8282
fn,
83-
block_indices,
83+
block_ids,
8484
block_size=block_size,
8585
loop_order=loop_order,
8686
)
8787
else:
8888
strategy = NDTileStrategy(
8989
fn,
90-
block_indices,
90+
block_ids,
9191
block_size=[bs.from_config_assert(config) for bs in block_size_infos],
9292
loop_order=loop_order,
9393
l2_grouping=l2_grouping,
9494
)
9595
self.strategies.append(strategy)
96-
self.block_indices_to_strategy[tuple(block_indices)] = strategy
96+
self.block_id_to_strategy[tuple(block_ids)] = strategy
9797

9898
def _add_reduction_strategies(self, fn: DeviceFunction, config: Config) -> None:
9999
env = CompileEnvironment.current()
@@ -107,20 +107,20 @@ def _add_reduction_strategies(self, fn: DeviceFunction, config: Config) -> None:
107107
else:
108108
strategy = LoopedReductionStrategy(fn, block_id, reduction_loop)
109109
self.strategies.append(strategy)
110-
self.block_indices_to_strategy[(block_id,)] = strategy
110+
self.block_id_to_strategy[(block_id,)] = strategy
111111

112-
def codegen_grid(self, state: CodegenState, block_indices: list[int]) -> None:
113-
strategy = self.block_indices_to_strategy[tuple(block_indices)]
112+
def codegen_grid(self, state: CodegenState, block_ids: list[int]) -> None:
113+
strategy = self.block_id_to_strategy[tuple(block_ids)]
114114
strategy.codegen_grid(state)
115115
for other_strategy in self.strategies:
116116
if other_strategy is not strategy:
117117
other_strategy.codegen_preamble(state)
118118
state.codegen.set_active_loops(DeviceGridState(strategy))
119119

120120
def codegen_device_loop(
121-
self, state: CodegenState, block_indices: list[int]
121+
self, state: CodegenState, block_ids: list[int]
122122
) -> DeviceLoopState:
123-
strategy = self.block_indices_to_strategy[tuple(block_indices)]
123+
strategy = self.block_id_to_strategy[tuple(block_ids)]
124124
return strategy.codegen_device_loop(state)
125125

126126
def _compact_shape(self, shapes: ShapeLike) -> list[CompactedShape]:
@@ -161,14 +161,14 @@ def expand_str(self, shape: ShapeLike, i: int) -> str:
161161
return f"[{', '.join(result)}]"
162162

163163
def get_reduction_strategy(self, block_idx: int) -> ReductionStrategy:
164-
strategy = self.block_indices_to_strategy[(block_idx,)]
164+
strategy = self.block_id_to_strategy[(block_idx,)]
165165
assert isinstance(strategy, ReductionStrategy)
166166
return strategy
167167

168168
def user_size(self, block_index: int) -> sympy.Expr:
169169
"""The user-visible size of the block index."""
170170
# This only does something special for reduction loops, only need to check for 1D loop
171-
strategy = self.block_indices_to_strategy.get((block_index,))
171+
strategy = self.block_id_to_strategy.get((block_index,))
172172
if strategy is None:
173173
return CompileEnvironment.current().block_sizes[block_index].symbol()
174174
return strategy.user_size(block_index)

0 commit comments

Comments
 (0)