18
18
from torch .fx .immutable_collections import immutable_list
19
19
from torch .fx .node import Argument , Target
20
20
21
- from ..utils import get_dynamic_dims , torch_dtype_from_trt , torch_dtype_to_trt
21
+ from ..utils import get_dynamic_dims , unified_dtype_converter , Frameworks
22
22
23
23
from .converter_utils import * # noqa: F403
24
24
from torch_tensorrt .fx .passes .lower_basic_pass import (
@@ -400,7 +400,7 @@ def acc_ops_pad_with_slice_layer(
400
400
)
401
401
402
402
# cast value to TRTensor
403
- dt = torch_dtype_from_trt (input_val .dtype )
403
+ dt = unified_dtype_converter (input_val .dtype , Frameworks . TORCH )
404
404
value = 0 if value == None else value
405
405
value_const = get_trt_tensor (
406
406
network , torch .tensor ([value ], dtype = dt ), f"{ name } _value"
@@ -1550,7 +1550,7 @@ def acc_ops_to_dtype(
1550
1550
input_t = get_trt_tensor (network , input_val , f"{ name } _input_t" )
1551
1551
if input_dtype :
1552
1552
if isinstance (input_dtype , torch .dtype ):
1553
- input_dtype = torch_dtype_to_trt (input_dtype )
1553
+ input_dtype = unified_dtype_converter (input_dtype , Frameworks . TRT )
1554
1554
input_t = type_cast (network , target , f"{ name } _input" , input_t , input_dtype )
1555
1555
return input_t
1556
1556
@@ -1811,7 +1811,7 @@ def acc_ops_logical_xor(
1811
1811
# f"isinf received input {input_t} that is not part "
1812
1812
# "of the TensorRT region!"
1813
1813
# )
1814
- # tdtype = torch_dtype_from_trt (input_t.dtype)
1814
+ # tdtype = unified_dtype_converter (input_t.dtype, Frameworks.TORCH )
1815
1815
1816
1816
# inf_t = torch.ones(tuple(input_t.shape))
1817
1817
# inf_t = inf_t * float("inf")
@@ -1849,7 +1849,7 @@ def acc_ops_any(
1849
1849
1850
1850
if input_t .dtype in (trt .float32 , trt .float16 , trt .int32 ):
1851
1851
comp_t = torch .zeros (tuple ([* input_t .shape ])).to (
1852
- torch_dtype_from_trt (input_t .dtype )
1852
+ unified_dtype_converter (input_t .dtype , Frameworks . TORCH )
1853
1853
)
1854
1854
comp_t = get_trt_tensor (network , comp_t , f"{ name } _comp_t" )
1855
1855
kwargs_new = {"input" : input_t , "other" : comp_t }
@@ -2738,7 +2738,7 @@ def acc_ops_masked_fill_tensor(
2738
2738
if type (value_t ) is torch .Tensor :
2739
2739
value_t = value_t .cpu ().numpy ()
2740
2740
# cast to input type
2741
- input_dtype = torch_dtype_from_trt (input_t .dtype )
2741
+ input_dtype = unified_dtype_converter (input_t .dtype , Frameworks . TORCH )
2742
2742
value_t = (torch .ones (shape ) * value_t ).to (input_dtype )
2743
2743
input_val = get_trt_tensor (network , input_t , f"{ name } _input" )
2744
2744
value_val = get_trt_tensor (network , value_t , f"{ name } _input" )
@@ -2872,7 +2872,11 @@ def add_clamp(network, input, val, op, name):
2872
2872
# clamping scalar
2873
2873
acc_ops_clamp_trt = get_trt_tensor (
2874
2874
network ,
2875
- squeeze_left (torch .tensor ([val ], dtype = torch_dtype_from_trt (input .dtype ))),
2875
+ squeeze_left (
2876
+ torch .tensor (
2877
+ [val ], dtype = unified_dtype_converter (input .dtype , Frameworks .TORCH )
2878
+ )
2879
+ ),
2876
2880
f"{ name } _clamp_{ val } " ,
2877
2881
)
2878
2882
else :
@@ -2881,7 +2885,8 @@ def add_clamp(network, input, val, op, name):
2881
2885
(
2882
2886
val
2883
2887
* torch .ones (
2884
- acc_ops_clamp_shape , dtype = torch_dtype_from_trt (input .dtype )
2888
+ acc_ops_clamp_shape ,
2889
+ dtype = unified_dtype_converter (input .dtype , Frameworks .TORCH ),
2885
2890
)
2886
2891
)
2887
2892
.cpu ()
@@ -3527,7 +3532,9 @@ def acc_ops_cumsum(
3527
3532
iterator = loop .add_iterator (input_val , dim , False )
3528
3533
data = iterator .get_output (0 )
3529
3534
new_dims = tuple (data .shape )
3530
- zero_tensor = torch .zeros (new_dims , dtype = trt_dtype_to_torch_dtype (input_val .dtype ))
3535
+ zero_tensor = torch .zeros (
3536
+ new_dims , dtype = unified_dtype_converter (input_val .dtype , Frameworks .TORCH )
3537
+ )
3531
3538
zero_tensor = network .add_constant (
3532
3539
zero_tensor .shape , to_numpy (zero_tensor )
3533
3540
).get_output (0 )
@@ -3670,7 +3677,7 @@ def acc_ops_new_ones(
3670
3677
dtype_val = kwargs .get ("dtype" )
3671
3678
if dtype_val is None :
3672
3679
dtype_val = input_val .dtype
3673
- dtype_val = torch_dtype_from_trt (dtype_val )
3680
+ dtype_val = unified_dtype_converter (dtype_val , Frameworks . TORCH )
3674
3681
3675
3682
device_val = kwargs .get ("device" )
3676
3683
assert (
@@ -3694,7 +3701,7 @@ def acc_ops_new_empty(
3694
3701
dtype_val = kwargs .get ("dtype" )
3695
3702
if dtype_val is None :
3696
3703
dtype_val = input_val .dtype
3697
- dtype_val = torch_dtype_from_trt (dtype_val )
3704
+ dtype_val = unified_dtype_converter (dtype_val , Frameworks . TORCH )
3698
3705
3699
3706
device_val = kwargs .get ("device" )
3700
3707
assert (
0 commit comments