Skip to content

Commit 4bd1fb6

Browse files
committed
update
1 parent a993212 commit 4bd1fb6

File tree

5 files changed

+4
-15
lines changed

5 files changed

+4
-15
lines changed

helion/_compiler/compile_environment.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,6 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
6969
collections.Counter()
7070
)
7171

72-
self.symint_to_grid_index_type: dict[str, GridIndexType] = {}
73-
7472
def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
7573
self.kernel_tensor_sizes[(*map(_to_sympy, sizes),)] += 1
7674

helion/_compiler/indexing_strategy.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111

1212
from .. import exc
1313
from .ast_extension import expr_from_string
14-
from .compile_environment import CompileEnvironment
14+
from .compile_environment import CompileEnvironment, GridBlockSizeSource
1515
from .host_function import HostFunction
1616
from .tile_strategy import TileStrategy
17+
from .tile_strategy import GridTileStrategy
1718
from .variable_origin import BlockSizeOrigin
1819

1920
if TYPE_CHECKING:
@@ -169,10 +170,7 @@ def compute_shape(
169170
if isinstance(symbol, sympy.Symbol):
170171
origin = HostFunction.current().symbol_to_origin.get(symbol.name)
171172
if origin and isinstance(origin.origin, BlockSizeOrigin):
172-
if (
173-
str(k)
174-
in CompileEnvironment.current().symint_to_grid_index_type
175-
):
173+
if isinstance(CompileEnvironment.current().block_sizes[origin.origin.block_size_idx].block_size_source, GridBlockSizeSource):
176174
pass
177175
elif tensor.size(tensor.ndim - len(input_size) - 1) != 1:
178176
output_size.append(k)
@@ -217,7 +215,7 @@ def create(
217215
if isinstance(symbol, sympy.Symbol):
218216
origin = HostFunction.current().symbol_to_origin.get(symbol.name)
219217
if origin and isinstance(origin.origin, BlockSizeOrigin):
220-
if str(k) in CompileEnvironment.current().symint_to_grid_index_type:
218+
if isinstance(CompileEnvironment.current().block_sizes[origin.origin.block_size_idx].block_size_source, GridBlockSizeSource):
221219
first_non_grid_index = n + 1
222220
expand = tile_strategy.expand_str(output_size, output_idx)
223221
else:

helion/_compiler/tile_dispatch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def _add_loop_strategy(
6161
env = CompileEnvironment.current()
6262
block_size_infos = [env.block_sizes[i] for i in block_indices]
6363
loop_order = block_size_infos[0].get_order(config, len(block_size_infos))
64-
print(f"block_size_infos: {block_size_infos}, block_size_infos[0]: {block_size_infos[0]}")
6564
if isinstance(block_size_infos[0].block_size_source, GridBlockSizeSource):
6665
strategy: TileStrategy = GridTileStrategy(
6766
fn,

helion/_compiler/tile_strategy.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,6 @@ def __init__(
145145
)
146146
self.block_size = block_size
147147
self.loop_order = loop_order
148-
if str(block_size) == "[1, 1]":
149-
import traceback
150-
traceback.print_stack()
151-
print(f"block_size: {block_size}, loop_order: {loop_order}, block_indices: {block_indices}")
152148

153149
def _reorder(self, block_indices: list[_T]) -> list[_T]:
154150
if len(block_indices) <= 1:
@@ -363,7 +359,6 @@ def __init__(
363359
loop_order: list[int],
364360
l2_grouping: int,
365361
) -> None:
366-
print(f"will create new NDTileStrategy")
367362
assert isinstance(block_size, list)
368363
super().__init__(fn, block_indices, block_size, loop_order)
369364
self.mask_vars: dict[int, str | None] = {}

helion/_compiler/type_propagation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,6 @@ def __init__(self, origin: Origin, block_size_idx: int) -> None:
984984

985985
env = CompileEnvironment.current()
986986
super().__init__(origin, env.block_sizes[block_size_idx].var)
987-
env.symint_to_grid_index_type[str(self.value)] = self
988987
self.block_size_idx = block_size_idx
989988

990989
def __str__(self) -> str: # pragma: no cover – debug helper

0 commit comments

Comments
 (0)