Commit 50ab2c1

fix: Repair aten.where with Numpy + Broadcast (#2372)
1 parent 0e4c5d8 commit 50ab2c1

File tree

2 files changed (+32, −17)


py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py (+15, −17)

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
 
 import numpy as np
 import tensorrt as trt
@@ -11,7 +11,7 @@
     get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.slice import expand
-from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 
 
@@ -20,23 +20,13 @@ def where(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
-    other: TRTTensor,
-    condition: TRTTensor,
+    input: Union[TRTTensor, np.ndarray, torch.Tensor],
+    other: Union[TRTTensor, np.ndarray, torch.Tensor],
+    condition: Union[TRTTensor, np.ndarray, torch.Tensor],
 ) -> TRTTensor:
     if not (broadcastable(input, other)):
         assert "The two torch tensors should be broadcastable"
 
-    # get output shape
-    # purpose of this is to bring input and other rank same as
-    # output_shape to input it to the add_expand operation
-    # condition will have dimension of either input or other
-    input, other = broadcast(ctx.net, input, other, f"{name}_x", f"{name}_y")
-    if len(tuple(condition.shape)) != len(tuple(input.shape)):
-        condition, input = broadcast(
-            ctx.net, condition, input, f"{name}_condition", f"{name}_x"
-        )
-
     x_shape = list(input.shape)
     y_shape = list(other.shape)
     condition_shape = list(condition.shape)
@@ -71,7 +61,11 @@ def where(
                     if isinstance(input, torch.Tensor)
                     else np.expand_dims(input, axis=0)
                 )
-            input = input.expand(output_shape)
+            input = (
+                input.expand(output_shape)
+                if isinstance(input, torch.Tensor)
+                else np.broadcast_to(input, output_shape)
+            )
         x_val = get_trt_tensor(ctx, input, f"{name}_x")
     else:
         x_val = input
@@ -89,7 +83,11 @@ def where(
                     if isinstance(other, torch.Tensor)
                     else np.expand_dims(other, axis=0)
                 )
-            other = other.expand(output_shape)
+            other = (
+                other.expand(output_shape)
+                if isinstance(other, torch.Tensor)
+                else np.broadcast_to(other, output_shape)
+            )
         y_val = get_trt_tensor(ctx, other, f"{name}_y")
     else:
         y_val = other
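
Why the change works: the old path sent `input` and `other` through the FX `broadcast` helper, which takes the TRT network (`ctx.net`) and operates on tensors inside it; constant inputs arriving as `np.ndarray` or `torch.Tensor` are now broadcast in plain Python before `get_trt_tensor` freezes them into TRT constants. Since `np.ndarray` has no `.expand` method, the diff dispatches on type. A minimal standalone sketch of that dispatch (`expand_to` is a hypothetical name, not part of the converter):

import numpy as np
import torch

def expand_to(value, output_shape):
    # torch tensors broadcast lazily via Tensor.expand; numpy arrays via
    # np.broadcast_to, which returns a read-only broadcast view.
    if isinstance(value, torch.Tensor):
        return value.expand(output_shape)
    return np.broadcast_to(value, output_shape)

x = expand_to(torch.randn(1), (5, 6, 7))      # torch path
y = expand_to(np.random.randn(1), (5, 6, 7))  # numpy path
assert tuple(x.shape) == tuple(y.shape) == (5, 6, 7)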

tests/py/dynamo/conversion/test_where_aten.py (+17)

@@ -59,6 +59,23 @@ def forward(self, condition):
             (condition,),
         )
 
+    def test_const_input_with_broadcast(self):
+        class Where(nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.inputY = torch.randn((1,))
+                self.inputX = torch.randn((1,))
+
+            def forward(self, condition):
+                return torch.ops.aten.where.self(condition, self.inputX, self.inputY)
+
+        input1 = torch.randn((5, 6, 7))
+        condition = input1 < 0
+        self.run_test(
+            Where(),
+            (condition,),
+        )
+
 
 if __name__ == "__main__":
     run_tests()
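
The added test pins down the regression: both branches of the `where` are (1,)-shaped constants held on the module, so the converter itself must broadcast them against the (5, 6, 7) condition instead of receiving pre-expanded TRT tensors. For reference, the eager-mode semantics the lowered engine must match (a hypothetical standalone repro, not part of the test suite):

import torch

condition = torch.randn(5, 6, 7) < 0
x, y = torch.randn(1), torch.randn(1)

# aten.where.self broadcasts the (1,)-shaped branches to (5, 6, 7)
out = torch.ops.aten.where.self(condition, x, y)
assert out.shape == (5, 6, 7)
assert torch.equal(out, torch.where(condition, x, y))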
