
Commit 07e8cd3

Rename block_size_idx to block_id
Some minor refactoring to make naming more consistent
1 parent 750a34f commit 07e8cd3

11 files changed: 55 additions & 57 deletions
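
The rename is mechanical: BlockSizeInfo's block_size_idx field, and the matching attribute on TileIndexType, GridIndexType, and BlockSizeOrigin, becomes block_id, with every call site updated to match. A minimal sketch of the convention the new name reflects, using simplified stand-ins for the real classes (the actual BlockSizeInfo carries more fields, such as size, var, and block_size_source):

    import dataclasses

    @dataclasses.dataclass
    class BlockSizeInfo:
        # block_id doubles as this entry's index in CompileEnvironment.block_sizes.
        block_id: int
        reduction: bool = False

    class CompileEnvironment:
        def __init__(self) -> None:
            self.block_sizes: list[BlockSizeInfo] = []

        def allocate_block_size(self, *, reduction: bool = False) -> BlockSizeInfo:
            # Mirrors the real allocate_block_size below: idx = len(self.block_sizes),
            # then append BlockSizeInfo(block_id=idx, ...).
            info = BlockSizeInfo(block_id=len(self.block_sizes), reduction=reduction)
            self.block_sizes.append(info)
            return info

Because allocate_block_size assigns block_id positionally, env.block_sizes[t.block_id] is direct list indexing, which is what the call sites in the diffs below rely on.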

helion/_compiler/compile_environment.py
Lines changed: 7 additions & 9 deletions

@@ -101,7 +101,7 @@ def allocate_block_size(
         idx = len(self.block_sizes)
         self.block_sizes.append(
             info := BlockSizeInfo(
-                block_size_idx=idx,
+                block_id=idx,
                 size=size,
                 var=self.create_block_var(
                     f"block_size_{idx}" if not reduction else f"rdim_{idx}",
@@ -310,7 +310,7 @@ class BlockSizeInfo:
     Used to track the block size for a given dimension.
     """
 
-    block_size_idx: int
+    block_id: int
     size: torch.SymInt | int | AutoSize | None
     var: torch.SymInt
     reduction: bool
@@ -348,7 +348,7 @@ def mark_alternate_size(self, size: torch.SymInt | int | None) -> None:
             with contextlib.suppress(KeyError):
                 # update the size hint now that we know the size
                 env.config_spec.block_sizes.block_id_lookup(
-                    self.block_size_idx
+                    self.block_id
                 ).update_hint(env.size_hint(size))
         elif size is None or self.size is None or self.size != size:
             self.size = None
@@ -357,7 +357,7 @@ def symbol(self) -> sympy.Symbol:
         return self.var._sympy_()
 
     def from_config(self, config: Config) -> int | torch.SymInt | None:
-        return self.block_size_source.from_config(config, self.block_size_idx)
+        return self.block_size_source.from_config(config, self.block_id)
 
     def from_config_assert(self, config: Config) -> int | torch.SymInt:
         val = self.from_config(config)
@@ -366,19 +366,17 @@ def from_config_assert(self, config: Config) -> int | torch.SymInt:
 
     def is_flattened(self, config: Config) -> bool:
         spec = CompileEnvironment.current().config_spec
-        return spec.flatten_loops.config_get(
-            config.flatten_loops, self.block_size_idx, False
-        )
+        return spec.flatten_loops.config_get(config.flatten_loops, self.block_id, False)
 
     def is_grid(self) -> bool:
         return self.block_size_source.is_grid()
 
     def update_min_block(self, value: int, *, allow_flattened: bool = True) -> None:
         spec = CompileEnvironment.current().config_spec
         if not allow_flattened:
-            spec.flatten_loops.disable_block_id(self.block_size_idx)
+            spec.flatten_loops.disable_block_id(self.block_id)
         with contextlib.suppress(KeyError):
-            spec.block_sizes.block_id_lookup(self.block_size_idx).update_min(value)
+            spec.block_sizes.block_id_lookup(self.block_id).update_min(value)
 
 
 class BlockSizeSource:
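
mark_alternate_size and update_min_block above look specs up by block_id via block_id_lookup, wrapping the call in contextlib.suppress(KeyError). A hedged sketch of the shape those call sites assume (BlockSizeSpec, BlockSizeSpecs, and the method bodies are illustrative stand-ins, not the real config-spec implementation):

    class BlockSizeSpec:
        def __init__(self, block_id: int, min_size: int = 1) -> None:
            self.block_id = block_id
            self.min_size = min_size

        def update_min(self, value: int) -> None:
            self.min_size = max(self.min_size, value)

    class BlockSizeSpecs:
        def __init__(self) -> None:
            self._by_id: dict[int, BlockSizeSpec] = {}

        def append(self, spec: BlockSizeSpec) -> None:
            self._by_id[spec.block_id] = spec

        def block_id_lookup(self, block_id: int) -> BlockSizeSpec:
            # KeyError is expected when the block was dropped from the
            # config space; callers like update_min_block suppress it.
            return self._by_id[block_id]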

helion/_compiler/device_function.py
Lines changed: 3 additions & 3 deletions

@@ -158,8 +158,8 @@ def __init__(self, name: str, config: Config) -> None:
         self.tile_strategy: TileStrategyDispatch = TileStrategyDispatch(self, config)
         self.indexing_strategy: IndexingStrategy = IndexingStrategy.select(config)
 
-    def block_size_var(self, block_size_idx: int) -> str | None:
-        return self.block_size_var_cache.get((block_size_idx,))
+    def block_size_var(self, block_id: int) -> str | None:
+        return self.block_size_var_cache.get((block_id,))
 
     def merge_variable_names(self, a: str, b: str) -> None:
         name_group = [
@@ -197,7 +197,7 @@ def _lift_sympy_arg(self, expr: sympy.Expr) -> str:
             )
             return arg.name
         if isinstance(origin.origin, BlockSizeOrigin):
-            result = self.block_size_var(origin.origin.block_size_idx)
+            result = self.block_size_var(origin.origin.block_id)
             assert result is not None
             return result
         return self.expr_arg(expr, origin.origin).name

helion/_compiler/device_ir.py
Lines changed: 5 additions & 5 deletions

@@ -321,7 +321,7 @@ def build_rolled_reductions(self) -> None:
                     new_graph, type(graph_info), **graph_info.kwargs()
                 )
                 reduction_info = RolledReductionInfo(
-                    rolled_block_indices=[rdim.block_size_idx],
+                    rolled_block_indices=[rdim.block_id],
                     original_graph_id=graph_id,
                     new_graph_id=new_graph_id,
                     used_rdim=len(roller.graphs_added) > 0,
@@ -335,7 +335,7 @@ def build_rolled_reductions(self) -> None:
                 # TODO(jansel): we should add support for rolling multiple dims at once
                 env.config_spec.reduction_loops.append(
                     ReductionLoopSpec(
-                        block_id=rdim.block_size_idx,
+                        block_id=rdim.block_id,
                         size_hint=rdim.size_hint(),
                     )
                 )
@@ -545,7 +545,7 @@ def run_subgraph(*args: object) -> list[object]:
         graph_idx = self.device_ir.add_graph(
             graph,
             ForLoopGraphInfo,
-            block_indices=[x.block_size_idx for x in iter_vars],
+            block_indices=[x.block_id for x in iter_vars],
             node_args=inputs.get_node_args(tracer),
         )
         args = (
@@ -826,9 +826,9 @@ def visit_For(self, node: ast.For) -> None:
             assert isinstance(iter_type, IterType)
             inner = iter_type.inner
             if isinstance(inner, SequenceType):
-                block_indices = [x.block_size_idx for x in inner.unpack()]
+                block_indices = [x.block_id for x in inner.unpack()]
             else:
-                block_indices = [inner.block_size_idx]
+                block_indices = [inner.block_id]
             self.device_ir.grid_block_indices.append(block_indices)
         else:
             self.generic_visit(node)

helion/_compiler/generate_ast.py
Lines changed: 2 additions & 2 deletions

@@ -212,7 +212,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
         if self.on_device:
             pass
         elif isinstance(type_info := node._type_info, TileIndexType):
-            block_info = env.block_sizes[type_info.block_size_idx]
+            block_info = env.block_sizes[type_info.block_id]
             return expr_from_string(
                 self.host_function.literal_expr(
                     block_info.from_config(self.device_function.config)
@@ -221,7 +221,7 @@
         elif isinstance(type_info, SequenceType):
             values = type_info.unpack()
             if all(isinstance(x, TileIndexType) for x in values):
-                block_infos = [env.block_sizes[x.block_size_idx] for x in values]
+                block_infos = [env.block_sizes[x.block_id] for x in values]
                 return expr_from_string(
                     self.host_function.literal_expr(
                         [

helion/_compiler/indexing_strategy.py
Lines changed: 8 additions & 8 deletions

@@ -220,7 +220,7 @@ def compute_shape(
         if origin and isinstance(origin.origin, BlockSizeOrigin):
             if (
                 CompileEnvironment.current()
-                .block_sizes[origin.origin.block_size_idx]
+                .block_sizes[origin.origin.block_id]
                 .is_grid()
             ):
                 pass
@@ -272,15 +272,15 @@ def create(
             if isinstance(symbol, sympy.Symbol):
                 origin = HostFunction.current().expr_to_origin.get(symbol)
             if origin and isinstance(origin.origin, BlockSizeOrigin):
-                index_var = state.codegen.index_var(origin.origin.block_size_idx)
-                if env.block_sizes[origin.origin.block_size_idx].is_grid():
+                index_var = state.codegen.index_var(origin.origin.block_id)
+                if env.block_sizes[origin.origin.block_id].is_grid():
                     index_values.append(index_var)
                     continue
                 expand = tile_strategy.expand_str(output_size, output_idx)
                 i = len(index_values)
                 index_values.append(f"({index_var}){expand}")
                 if (
-                    mask := state.codegen.mask_var(origin.origin.block_size_idx)
+                    mask := state.codegen.mask_var(origin.origin.block_id)
                 ) and fake_value.size(i) != 1:
                     mask_values.setdefault(f"({mask}){expand}")
                 output_idx += 1
@@ -293,7 +293,7 @@
                 size = fake_value.size(len(index_values))
                 if size != 1:
                     rdim = env.allocate_reduction_dimension(size)
-                    block_idx = rdim.block_size_idx
+                    block_idx = rdim.block_id
                     index_var = state.codegen.index_var(block_idx)
                     index_values.append(f"({index_var}){expand}")
                     if mask := state.codegen.mask_var(block_idx):
@@ -451,7 +451,7 @@ def is_supported(
         if isinstance(symbol, sympy.Symbol):
             origin = HostFunction.current().expr_to_origin.get(symbol)
             if origin and isinstance(origin.origin, BlockSizeOrigin):
-                block_index = origin.origin.block_size_idx
+                block_index = origin.origin.block_id
                 try:
                     state.codegen.offset_var(block_index)
                 except NotImplementedError:
@@ -507,7 +507,7 @@ def create(
             if origin and isinstance(origin.origin, BlockSizeOrigin):
                 if fake_value.size(len(res.offsets)) != 1:
                     res.offsets.append(
-                        state.codegen.offset_var(origin.origin.block_size_idx)
+                        state.codegen.offset_var(origin.origin.block_id)
                     )
                     res.block_shape.append(k)
                 else:
@@ -521,7 +521,7 @@
                 if size != 1:
                     env = CompileEnvironment.current()
                     rdim = env.allocate_reduction_dimension(size)
-                    res.offsets.append(state.codegen.offset_var(rdim.block_size_idx))
+                    res.offsets.append(state.codegen.offset_var(rdim.block_id))
                     res.block_shape.append(rdim.var)
                 else:
                     res.offsets.append("0")

helion/_compiler/roll_reduction.py
Lines changed: 4 additions & 4 deletions

@@ -60,7 +60,7 @@ def is_reduction(self, node: torch.fx.Node) -> bool:
         return (
             node.op == "call_function"
             and isinstance(lowering := node.meta["lowering"], ReductionLowering)
-            and lowering.block_index == self.rdim.block_size_idx
+            and lowering.block_index == self.rdim.block_id
         )
 
     def should_go_in_inner_graph(self, node: torch.fx.Node) -> bool:
@@ -103,7 +103,7 @@ def should_go_in_inner_graph(self, node: torch.fx.Node) -> bool:
             if isinstance(val, torch.Tensor):
                 for size in val.size():
                     block_idx = TileStrategy.get_block_index(size)
-                    num_rdims += block_idx == self.rdim.block_size_idx
+                    num_rdims += block_idx == self.rdim.block_id
             if num_rdims > 1:
                 raise NotImplementedError(
                     "multiple reduction dims of same size not supported"
@@ -121,7 +121,7 @@ def size_node(self, meta: dict[str, object]) -> torch.fx.Node:
             return self._size_node
         self._size_node = node = self.outer_graph.call_function(
            _get_symnode,
-            (f"rdim{self.rdim.block_size_idx}",),
+            (f"rdim{self.rdim.block_id}",),
            {},
        )
        node.meta.update(meta)
@@ -143,7 +143,7 @@ def start_new_graph(self) -> None:
         graph.output([*outputs.values()])
         graph_id = self.device_ir.add_reduction_loop_graph(
             graph,
-            block_index=self.rdim.block_size_idx,
+            block_index=self.rdim.block_id,
             node_args=self.inner_args,
         )
         self.graphs_added.append(graph_id)

helion/_compiler/tile_dispatch.py
Lines changed: 1 addition & 1 deletion

@@ -97,7 +97,7 @@ def _add_loop_strategy(
 
     def _add_reduction_strategies(self, fn: DeviceFunction, config: Config) -> None:
         env = CompileEnvironment.current()
-        rdims = [bs.block_size_idx for bs in env.block_sizes if bs.reduction]
+        rdims = [bs.block_id for bs in env.block_sizes if bs.reduction]
         for block_id in rdims:
             reduction_loop = env.config_spec.reduction_loops.config_get(
                 config.reduction_loops, block_id, None

helion/_compiler/tile_strategy.py
Lines changed: 1 addition & 1 deletion

@@ -127,7 +127,7 @@ def get_block_index(cls, size: int | torch.SymInt | sympy.Expr) -> int | None:
             origin_info.origin,
             BlockSizeOrigin,
         ):
-            return origin_info.origin.block_size_idx
+            return origin_info.origin.block_id
         return None

helion/_compiler/type_propagation.py
Lines changed: 15 additions & 15 deletions

@@ -468,7 +468,7 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
             output_sizes.append(1)
         elif isinstance(k, TileIndexType):
             inputs_consumed += 1
-            output_sizes.append(env.block_sizes[k.block_size_idx].var)
+            output_sizes.append(env.block_sizes[k.block_id].var)
         elif isinstance(k, TypeNotAllowedOnDevice):
             raise exc.TypePropagationError(k)
         elif isinstance(k, TensorType) and k.fake_value.ndim == 1:
@@ -944,22 +944,22 @@ def _get_hint(numel: int | torch.SymInt | AutoSize | None) -> int:
 
 
 class TileIndexType(TypeInfo):
-    block_size_idx: int
+    block_id: int
 
     def __str__(self) -> str:
-        return f"{type(self).__name__}({self.block_size_idx})"
+        return f"{type(self).__name__}({self.block_id})"
 
-    def __init__(self, origin: Origin, block_size_idx: int) -> None:
+    def __init__(self, origin: Origin, block_id: int) -> None:
         super().__init__(origin)
-        self.block_size_idx = block_size_idx
+        self.block_id = block_id
 
     def proxy(self) -> object:
         with proxy_tensor.disable_proxy_modes_tracing():
             fake_mode = torch._C._unset_dispatch_mode(
                 torch._C._TorchDispatchModeKey.FAKE
             )
             try:
-                return TileIndexProxy(self.block_size_idx)
+                return TileIndexProxy(self.block_id)
             finally:
                 assert fake_mode is not None
                 torch._C._set_dispatch_mode(fake_mode)
@@ -992,10 +992,10 @@ def allocate_fixed(
 
     def merge(self, other: TypeInfo) -> TypeInfo:
         if isinstance(other, TileIndexType):
-            if self.block_size_idx == other.block_size_idx:
+            if self.block_id == other.block_id:
                 return self
             return UnknownType(
-                debug_msg=f"TileIndexType mismatch in control flow: {self.block_size_idx} and {other.block_size_idx}",
+                debug_msg=f"TileIndexType mismatch in control flow: {self.block_id} and {other.block_id}",
                 origin=other.origin,
             )
         return super().merge(other)
@@ -1007,17 +1007,17 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
 
 
 class GridIndexType(SymIntType):
-    block_size_idx: int
+    block_id: int
 
-    def __init__(self, origin: Origin, block_size_idx: int) -> None:
+    def __init__(self, origin: Origin, block_id: int) -> None:
         from .._compiler.compile_environment import CompileEnvironment
 
         env = CompileEnvironment.current()
-        super().__init__(origin, env.block_sizes[block_size_idx].var)
-        self.block_size_idx = block_size_idx
+        super().__init__(origin, env.block_sizes[block_id].var)
+        self.block_id = block_id
 
     def __str__(self) -> str:  # pragma: no cover – debug helper
-        return f"{type(self).__name__}({self.block_size_idx})"
+        return f"{type(self).__name__}({self.block_id})"
 
     @staticmethod
     def allocate(numel: int | torch.SymInt, origin: Origin) -> GridIndexType:
@@ -1030,10 +1030,10 @@ def allocate(numel: int | torch.SymInt, origin: Origin) -> GridIndexType:
 
     def merge(self, other: TypeInfo) -> TypeInfo:  # type: ignore[override]
         if isinstance(other, GridIndexType):
-            if self.block_size_idx == other.block_size_idx:
+            if self.block_id == other.block_id:
                 return self
             return UnknownType(
-                debug_msg=f"GridIndexType mismatch in control flow: {self.block_size_idx} vs {other.block_size_idx}",
+                debug_msg=f"GridIndexType mismatch in control flow: {self.block_id} vs {other.block_id}",
                 origin=other.origin,
             )
         return super().merge(other)
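
TileIndexType.merge and GridIndexType.merge above encode the same rule: two values meeting at a control-flow join stay typed only when they carry the same block_id. A standalone sketch of that rule under simplified placeholder types (TypeInfo and UnknownType here are reduced stand-ins for the real classes):

    from __future__ import annotations

    class TypeInfo:
        def merge(self, other: TypeInfo) -> TypeInfo:
            # Fallback when the two types are unrelated.
            return UnknownType(f"cannot merge {self} with {other}")

    class UnknownType(TypeInfo):
        def __init__(self, debug_msg: str) -> None:
            self.debug_msg = debug_msg

    class TileIndexType(TypeInfo):
        def __init__(self, block_id: int) -> None:
            self.block_id = block_id

        def merge(self, other: TypeInfo) -> TypeInfo:
            if isinstance(other, TileIndexType):
                if self.block_id == other.block_id:
                    return self
                return UnknownType(
                    f"TileIndexType mismatch in control flow: {self.block_id} and {other.block_id}"
                )
            return super().merge(other)

    # Same block_id merges cleanly; different block_ids degrade to UnknownType.
    assert TileIndexType(0).merge(TileIndexType(0)).block_id == 0
    assert isinstance(TileIndexType(0).merge(TileIndexType(1)), UnknownType)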

helion/_compiler/variable_origin.py
Lines changed: 2 additions & 2 deletions

@@ -221,7 +221,7 @@ class DeviceOrigin(Origin):
 
 @dataclasses.dataclass
 class BlockSizeOrigin(Origin):
-    block_size_idx: int
+    block_id: int
 
     def host_str(self) -> str:
         """
@@ -232,7 +232,7 @@ def host_str(self) -> str:
         from .device_function import DeviceFunction
 
         # Look up the block size variable name; if not set (e.g., size==1), use literal 1
-        var = DeviceFunction.current().block_size_var(self.block_size_idx)
+        var = DeviceFunction.current().block_size_var(self.block_id)
         if var is None:
             return "1"
         return var
