1
- from typing import Optional
1
+ from typing import Optional , Union
2
2
3
3
import numpy as np
4
4
import tensorrt as trt
11
11
get_trt_tensor ,
12
12
)
13
13
from torch_tensorrt .dynamo .conversion .impl .slice import expand
14
- from torch_tensorrt .fx .converters .converter_utils import broadcast , set_layer_name
14
+ from torch_tensorrt .fx .converters .converter_utils import set_layer_name
15
15
from torch_tensorrt .fx .types import TRTTensor
16
16
17
17
@@ -20,23 +20,13 @@ def where(
20
20
target : Target ,
21
21
source_ir : Optional [SourceIR ],
22
22
name : str ,
23
- input : TRTTensor ,
24
- other : TRTTensor ,
25
- condition : TRTTensor ,
23
+ input : Union [ TRTTensor , np . ndarray , torch . Tensor ] ,
24
+ other : Union [ TRTTensor , np . ndarray , torch . Tensor ] ,
25
+ condition : Union [ TRTTensor , np . ndarray , torch . Tensor ] ,
26
26
) -> TRTTensor :
27
27
if not (broadcastable (input , other )):
28
28
assert "The two torch tensors should be broadcastable"
29
29
30
- # get output shape
31
- # purpose of this is to bring input and other rank same as
32
- # output_shape to input it to the add_expand operation
33
- # condition will have dimension of either input or other
34
- input , other = broadcast (ctx .net , input , other , f"{ name } _x" , f"{ name } _y" )
35
- if len (tuple (condition .shape )) != len (tuple (input .shape )):
36
- condition , input = broadcast (
37
- ctx .net , condition , input , f"{ name } _condition" , f"{ name } _x"
38
- )
39
-
40
30
x_shape = list (input .shape )
41
31
y_shape = list (other .shape )
42
32
condition_shape = list (condition .shape )
@@ -71,7 +61,11 @@ def where(
71
61
if isinstance (input , torch .Tensor )
72
62
else np .expand_dims (input , axis = 0 )
73
63
)
74
- input = input .expand (output_shape )
64
+ input = (
65
+ input .expand (output_shape )
66
+ if isinstance (input , torch .Tensor )
67
+ else np .broadcast_to (input , output_shape )
68
+ )
75
69
x_val = get_trt_tensor (ctx , input , f"{ name } _x" )
76
70
else :
77
71
x_val = input
@@ -89,7 +83,11 @@ def where(
89
83
if isinstance (other , torch .Tensor )
90
84
else np .expand_dims (other , axis = 0 )
91
85
)
92
- other = other .expand (output_shape )
86
+ other = (
87
+ other .expand (output_shape )
88
+ if isinstance (other , torch .Tensor )
89
+ else np .broadcast_to (other , output_shape )
90
+ )
93
91
y_val = get_trt_tensor (ctx , other , f"{ name } _y" )
94
92
else :
95
93
y_val = other
0 commit comments