Commit fd740c1

[Error Message] Update block config size length mismatch (#139)
1 parent 4c0ad72 commit fd740c1

4 files changed: +109 −5 lines changed

helion/_compiler/type_propagation.py

Lines changed: 12 additions & 3 deletions

@@ -479,7 +479,11 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
             else:
                 raise exc.InvalidIndexingType(k)
         if inputs_consumed != self.fake_value.ndim:
-            raise exc.RankMismatch(self.fake_value.ndim, inputs_consumed)
+            raise exc.RankMismatch(
+                self.fake_value.ndim,
+                inputs_consumed,
+                f"tensor shape: {tuple(self.fake_value.shape)}",
+            )
         return output_sizes
 
     def propagate_setitem(
@@ -488,11 +492,16 @@ def propagate_setitem(
         if origin.is_host():
             warning(exc.TensorOperationInWrapper)
         else:
-            lhs_rank = len(self._device_indexing_size(key))
+            lhs_shape = self._device_indexing_size(key)
+            lhs_rank = len(lhs_shape)
             if isinstance(value, TensorType):
                 rhs_rank = value.fake_value.ndim
                 if lhs_rank != rhs_rank:
-                    raise exc.RankMismatch(lhs_rank, rhs_rank)
+                    raise exc.RankMismatch(
+                        lhs_rank,
+                        rhs_rank,
+                        f"LHS shape: {tuple(lhs_shape)}, RHS shape: {tuple(value.fake_value.shape)}",
+                    )
             elif isinstance(value, UnknownType):
                 raise exc.TypePropagationError(value)
             else:
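For a concrete sense of what the two call sites now attach to the error, here is a small runnable sketch of the `shape_info` strings they build. The shapes are made up for illustration; only the f-string formats come from the diff above.

    # Hypothetical shapes, for illustration only.
    fake_value_shape = (4, 8)               # tensor being indexed
    lhs_shape, rhs_shape = (16, 32), (16,)  # the two sides of a setitem

    # Indexing path (first hunk): passed as RankMismatch's new third argument.
    print(f"tensor shape: {tuple(fake_value_shape)}")
    # tensor shape: (4, 8)

    # Assignment path (second hunk):
    print(f"LHS shape: {tuple(lhs_shape)}, RHS shape: {tuple(rhs_shape)}")
    # LHS shape: (16, 32), RHS shape: (16,)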

helion/autotuner/block_id_sequence.py

Lines changed: 3 additions & 1 deletion

@@ -159,7 +159,9 @@ def _normalize(
                 values.append(spec._fill_missing())
             except NotImplementedError:
                 raise InvalidConfig(
-                    f"Not enough values for config[{name!r}], expected {size}, got {len(values)}"
+                    f"Not enough values for config[{name!r}]: expected {size} block sizes "
+                    f"(one for each tiled dimension), got {len(values)}. "
+                    f"Did you forget to specify block sizes for all your hl.tile() dimensions?"
                 ) from None
         for i, spec in enumerate(self._data):
            values[i] = spec._normalize(f"config[{name}][{i}]", values[i])
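To see the full message this branch now produces, a minimal standalone sketch with hypothetical stand-ins for the method's locals (`name`, `size`, and `values` here mirror the test added below, not actual runtime state):

    name, size, values = "block_sizes", 3, [32, 64]  # hypothetical stand-ins
    msg = (
        f"Not enough values for config[{name!r}]: expected {size} block sizes "
        f"(one for each tiled dimension), got {len(values)}. "
        f"Did you forget to specify block sizes for all your hl.tile() dimensions?"
    )
    print(msg)
    # Not enough values for config['block_sizes']: expected 3 block sizes
    # (one for each tiled dimension), got 2. Did you forget to specify
    # block sizes for all your hl.tile() dimensions?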

helion/exc.py

Lines changed: 20 additions & 1 deletion

@@ -82,7 +82,26 @@ class NestedGridLoop(BaseError):
 
 
 class RankMismatch(BaseError):
-    message = "Expected ndim={0}, but got ndim={1}"
+    message = "Expected ndim={expected_ndim}, but got ndim={actual_ndim}{shape_part}. You have {direction}."
+
+    def __init__(
+        self, expected_ndim: int, actual_ndim: int, shape_info: str = ""
+    ) -> None:
+        if actual_ndim > expected_ndim:
+            direction = "too many indices"
+        elif actual_ndim < expected_ndim:
+            direction = "too few indices"
+        else:
+            direction = "indices that don't match expected structure"
+
+        shape_part = f" ({shape_info})" if shape_info else ""
+
+        super().__init__(
+            expected_ndim=expected_ndim,
+            actual_ndim=actual_ndim,
+            shape_part=shape_part,
+            direction=direction,
+        )
 
 
 class InvalidIndexingType(BaseError):
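A self-contained sketch of how the new message reads end to end. The two-line BaseError stub is an assumption, not helion's real base class: the `super().__init__(...)` call above implies BaseError formats `message` with the keyword arguments it receives.

    class BaseError(Exception):
        """Stub standing in for helion's BaseError (assumed keyword formatting)."""
        message = ""

        def __init__(self, **kwargs: object) -> None:
            super().__init__(self.message.format(**kwargs))

    class RankMismatch(BaseError):
        message = "Expected ndim={expected_ndim}, but got ndim={actual_ndim}{shape_part}. You have {direction}."

        def __init__(self, expected_ndim: int, actual_ndim: int, shape_info: str = "") -> None:
            if actual_ndim > expected_ndim:
                direction = "too many indices"
            elif actual_ndim < expected_ndim:
                direction = "too few indices"
            else:
                direction = "indices that don't match expected structure"
            shape_part = f" ({shape_info})" if shape_info else ""
            super().__init__(
                expected_ndim=expected_ndim,
                actual_ndim=actual_ndim,
                shape_part=shape_part,
                direction=direction,
            )

    print(RankMismatch(1, 2, "tensor shape: (8,)"))
    # Expected ndim=1, but got ndim=2 (tensor shape: (8,)). You have too many indices.
    print(RankMismatch(2, 0))
    # Expected ndim=2, but got ndim=0. You have too few indices.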

test/test_errors.py

Lines changed: 74 additions & 0 deletions

@@ -35,3 +35,77 @@ def fn(x: torch.Tensor) -> torch.Tensor:
 
         with self.assertRaises(helion.exc.OverpackedTile):
             code_and_output(fn, (torch.randn(100, 100, device=DEVICE),))
+
+    def test_invalid_config_insufficient_block_sizes(self):
+        """Test that InvalidConfig shows a helpful message for missing block sizes."""
+
+        @helion.kernel(config=helion.Config(block_sizes=[32, 64]))
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            batch, seq_len, hidden = x.size()
+            out = torch.empty_like(x)
+            for tile_batch, tile_seq, tile_hidden in hl.tile([batch, seq_len, hidden]):
+                out[tile_batch, tile_seq, tile_hidden] = x[
+                    tile_batch, tile_seq, tile_hidden
+                ]
+            return out
+
+        with self.assertRaisesRegex(
+            helion.exc.InvalidConfig,
+            r"Not enough values for config.*expected 3 block sizes.*got 2.*"
+            r"Did you forget to specify block sizes for all your hl\.tile\(\) dimensions\?",
+        ):
+            code_and_output(
+                fn,
+                (torch.randn(4, 8, 16, device=DEVICE),),
+            )
+
+    def test_rank_mismatch_assignment(self):
+        """Test that RankMismatch shows tensor shapes in assignment errors."""
+
+        @helion.kernel()
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            batch, seq_len = x.size()
+            out = x.new_empty(batch, seq_len)
+            for tile_batch, tile_seq in hl.tile([batch, seq_len]):
+                scalar_val = x[tile_batch, 0].sum()  # creates a 0D tensor
+                out[tile_batch, tile_seq] = scalar_val  # 0D -> 2D assignment
+            return out
+
+        with self.assertRaisesRegex(
+            helion.exc.RankMismatch,
+            r"Expected ndim=2, but got ndim=0.*You have too few indices",
+        ):
+            code_and_output(fn, (torch.randn(4, 8, device=DEVICE),))
+
+    def test_rank_mismatch_indexing(self):
+        """Test that RankMismatch shows tensor shapes in indexing errors."""
+
+        @helion.kernel()
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            batch = x.size(0)
+            out = x.new_empty(batch)
+            for tile_batch in hl.tile([batch]):
+                scalar_val = x[tile_batch].sum()  # 1D index for a 2D tensor
+                out = scalar_val
+            return out
+
+        with self.assertRaisesRegex(
+            helion.exc.RankMismatch,
+            r"Expected ndim=2, but got ndim=1.*You have too few indices",
+        ):
+            code_and_output(fn, (torch.randn(4, 8, device=DEVICE),))
+
+    def test_rank_mismatch_indexing_too_many(self):
+        @helion.kernel()
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            batch = x.size(0)
+            fill = x.new_empty(batch, batch)
+            for tile_batch in hl.tile(batch):
+                fill = x[tile_batch, tile_batch]  # 2D index for a 1D tensor
+            return fill
+
+        with self.assertRaisesRegex(
+            helion.exc.RankMismatch,
+            r"Expected ndim=1, but got ndim=2.*You have too many indices",
+        ):
+            code_and_output(fn, (torch.randn(8, device=DEVICE),))
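These four tests cover both rank-mismatch directions plus the config case. Assuming the repo's tests run under pytest (an assumption; the file itself only uses unittest assertions), they can be run in isolation with something like:

    python -m pytest test/test_errors.py -v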
