Skip to content

Commit c145af0

Browse files
authored
Improve error message for overpacked tiles (#126)
1 parent 931ea4d commit c145af0

File tree

3 files changed

+21
-0
lines changed

3 files changed

+21
-0
lines changed

helion/_compiler/type_propagation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,8 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
474474
elif isinstance(k, TensorType) and k.fake_value.ndim == 1:
475475
inputs_consumed += 1
476476
output_sizes.append(k.fake_value.size(0))
477+
elif k.contains_type(TileIndexType):
478+
raise exc.OverpackedTile(k)
477479
else:
478480
raise exc.InvalidIndexingType(k)
479481
if inputs_consumed != self.fake_value.ndim:

helion/exc.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ class FailedToUnpackTile(BaseError):
125125
)
126126

127127

128+
class OverpackedTile(BaseError):
    """Raised when a tensor index contains a tile hidden inside a container.

    The offending key is interpolated into the message via ``{0!s}``.
    """

    message = "Got a tile wrapped inside a container when indexing a tensor: {0!s}\nDid you mix up `hl.tile([x])` and `hl.tile(x)`?"
133+
134+
128135
class AssignmentMultipleTargets(NotAllowedOnDevice):
129136
message = "Assignment with multiple targets (a=b=1) is not allowed inside the `hl.tile` or `hl.grid` loop."
130137

test/test_errors.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,15 @@ def sum_kernel(x: torch.Tensor) -> torch.Tensor:
2323

2424
with self.assertRaises(helion.exc.FailedToUnpackTile):
2525
code_and_output(sum_kernel, (torch.randn(2, 3, 4, device=DEVICE),))
26+
27+
def test_tile_overpacking(self):
    """Indexing a tensor with a tile still wrapped in its container must
    raise ``helion.exc.OverpackedTile`` (not a generic indexing error)."""

    @helion.kernel()
    def row_sum(x: torch.Tensor) -> torch.Tensor:
        batch = x.size(0)
        out = x.new_empty(batch)
        # `hl.tile([batch])` yields the tile wrapped in a container (the
        # variable name reflects the user mistake under test), so using the
        # yielded value directly as an index should be rejected.
        for tile_wrapped_in_tuple in hl.tile([batch]):
            out[tile_wrapped_in_tuple] = x[tile_wrapped_in_tuple, :].sum(1)
        return out

    with self.assertRaises(helion.exc.OverpackedTile):
        code_and_output(row_sum, (torch.randn(100, 100, device=DEVICE),))

0 commit comments

Comments (0)