Commit 9305811

fix: Refactor data type handling in FX
1 parent: e1a8611
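
This commit replaces the directional dtype helpers in torch_tensorrt.fx.utils (torch_dtype_to_trt, torch_dtype_from_trt, and trt_dtype_to_torch_dtype) with a single entry point, unified_dtype_converter(dtype, Frameworks.<TARGET>), which takes a numpy, torch, or TensorRT dtype and returns its equivalent in the requested framework. Below is a minimal sketch of how such a converter can be structured; it only mirrors the names visible in this diff, and the actual table and error handling in torch_tensorrt.fx.utils may differ.

from enum import Enum

import numpy as np
import tensorrt as trt
import torch


class Frameworks(Enum):
    NUMPY = "numpy"
    TORCH = "torch"
    TRT = "trt"


# Equivalence table keyed by TensorRT dtype, one column per target framework.
DataTypeEquivalence = {
    trt.int32: {
        Frameworks.NUMPY: np.int32,
        Frameworks.TORCH: torch.int32,
        Frameworks.TRT: trt.int32,
    },
    trt.float16: {
        Frameworks.NUMPY: np.float16,
        Frameworks.TORCH: torch.float16,
        Frameworks.TRT: trt.float16,
    },
    trt.float32: {
        Frameworks.NUMPY: np.float32,
        Frameworks.TORCH: torch.float32,
        Frameworks.TRT: trt.float32,
    },
}


def unified_dtype_converter(dtype, to: Frameworks):
    # Normalize the incoming dtype to its TensorRT key, then look up the
    # equivalent dtype for the requested target framework.
    if isinstance(dtype, torch.dtype):
        key = {
            torch.int32: trt.int32,
            torch.float16: trt.float16,
            torch.float32: trt.float32,
        }[dtype]
    elif isinstance(dtype, trt.DataType):
        key = dtype
    else:  # assume something numpy can interpret as a dtype
        key = {
            np.dtype(np.int32): trt.int32,
            np.dtype(np.float16): trt.float16,
            np.dtype(np.float32): trt.float32,
        }[np.dtype(dtype)]
    return DataTypeEquivalence[key][to]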

File tree

8 files changed: +177 -112 lines changed


py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py (+3)

@@ -54,6 +54,7 @@ def forward(self, x):
             0,
             msg=f"MulInt TRT outputs don't match with the original model.",
         )
+        torch._dynamo.reset()
 
     def test_lowering_add_float(self):
         class AddFloat(torch.nn.Module):
@@ -106,6 +107,8 @@ def forward(self, x):
             msg=f"AddFloat TRT outputs don't match with the original model.",
         )
 
+        torch._dynamo.reset()
+
 
 if __name__ == "__main__":
     run_tests()
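
The only functional change here is test hygiene: torch._dynamo.reset() clears TorchDynamo's cached graphs and guards after each case so one test's compilation state cannot leak into the next. A minimal sketch of the pattern, with a hypothetical helper name and a generic torch.compile call standing in for the actual test harness:

import torch


def run_once_and_reset(model, *inputs):
    compiled = torch.compile(model)  # compile via TorchDynamo
    out = compiled(*inputs)  # first call triggers tracing/compilation
    torch._dynamo.reset()  # drop cached graphs and guards afterwards
    return out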

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py (+9 -2)

@@ -16,7 +16,12 @@
 from torch_tensorrt.dynamo.fx_ts_compat import CONVERTERS
 from .input_tensor_spec import InputTensorSpec
 from torch_tensorrt.fx.observer import Observer
-from torch_tensorrt.fx.utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt
+from torch_tensorrt.fx.utils import (
+    get_dynamic_dims,
+    LowerPrecision,
+    unified_dtype_converter,
+    Frameworks,
+)
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -321,7 +326,9 @@ def placeholder(self, target, args, kwargs):
                 self.optimization_profiles[i].set_shape(target, *shape_range)
 
         return self.network.add_input(
-            name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)
+            name=target,
+            shape=tuple(shape),
+            dtype=unified_dtype_converter(dtype, Frameworks.TRT),
         )
 
     def call_module(self, target, args, kwargs):
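
placeholder() now routes the input dtype through the unified helper and asks explicitly for the TensorRT flavor. A short usage sketch of the call as it appears above (the tensor name, shape, and the commented-out add_input line are illustrative):

import torch

from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

# torch.float32 maps to the TensorRT float32 dtype
trt_dtype = unified_dtype_converter(torch.float32, Frameworks.TRT)
# network.add_input(name="x", shape=(1, 3, 224, 224), dtype=trt_dtype)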

py/torch_tensorrt/fx/converters/acc_ops_converters.py (+18 -11)

@@ -18,7 +18,7 @@
 from torch.fx.immutable_collections import immutable_list
 from torch.fx.node import Argument, Target
 
-from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
+from ..utils import get_dynamic_dims, unified_dtype_converter, Frameworks
 
 from .converter_utils import *  # noqa: F403
 from torch_tensorrt.fx.passes.lower_basic_pass import (
@@ -400,7 +400,7 @@ def acc_ops_pad_with_slice_layer(
     )
 
     # cast value to TRTensor
-    dt = torch_dtype_from_trt(input_val.dtype)
+    dt = unified_dtype_converter(input_val.dtype, Frameworks.TORCH)
     value = 0 if value == None else value
     value_const = get_trt_tensor(
         network, torch.tensor([value], dtype=dt), f"{name}_value"
@@ -1550,7 +1550,7 @@ def acc_ops_to_dtype(
     input_t = get_trt_tensor(network, input_val, f"{name}_input_t")
     if input_dtype:
         if isinstance(input_dtype, torch.dtype):
-            input_dtype = torch_dtype_to_trt(input_dtype)
+            input_dtype = unified_dtype_converter(input_dtype, Frameworks.TRT)
         input_t = type_cast(network, target, f"{name}_input", input_t, input_dtype)
     return input_t
 
@@ -1811,7 +1811,7 @@ def acc_ops_logical_xor(
    #     f"isinf received input {input_t} that is not part "
    #     "of the TensorRT region!"
    # )
-   # tdtype = torch_dtype_from_trt(input_t.dtype)
+   # tdtype = unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
 
    # inf_t = torch.ones(tuple(input_t.shape))
    # inf_t = inf_t * float("inf")
@@ -1849,7 +1849,7 @@ def acc_ops_any(
 
     if input_t.dtype in (trt.float32, trt.float16, trt.int32):
         comp_t = torch.zeros(tuple([*input_t.shape])).to(
-            torch_dtype_from_trt(input_t.dtype)
+            unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
         )
         comp_t = get_trt_tensor(network, comp_t, f"{name}_comp_t")
     kwargs_new = {"input": input_t, "other": comp_t}
@@ -2738,7 +2738,7 @@ def acc_ops_masked_fill_tensor(
     if type(value_t) is torch.Tensor:
         value_t = value_t.cpu().numpy()
     # cast to input type
-    input_dtype = torch_dtype_from_trt(input_t.dtype)
+    input_dtype = unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
     value_t = (torch.ones(shape) * value_t).to(input_dtype)
     input_val = get_trt_tensor(network, input_t, f"{name}_input")
     value_val = get_trt_tensor(network, value_t, f"{name}_input")
@@ -2872,7 +2872,11 @@ def add_clamp(network, input, val, op, name):
         # clamping scalar
         acc_ops_clamp_trt = get_trt_tensor(
             network,
-            squeeze_left(torch.tensor([val], dtype=torch_dtype_from_trt(input.dtype))),
+            squeeze_left(
+                torch.tensor(
+                    [val], dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH)
+                )
+            ),
             f"{name}_clamp_{val}",
         )
     else:
@@ -2881,7 +2885,8 @@ def add_clamp(network, input, val, op, name):
             (
                 val
                 * torch.ones(
-                    acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)
+                    acc_ops_clamp_shape,
+                    dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
                 )
             )
             .cpu()
@@ -3527,7 +3532,9 @@ def acc_ops_cumsum(
     iterator = loop.add_iterator(input_val, dim, False)
     data = iterator.get_output(0)
     new_dims = tuple(data.shape)
-    zero_tensor = torch.zeros(new_dims, dtype=trt_dtype_to_torch_dtype(input_val.dtype))
+    zero_tensor = torch.zeros(
+        new_dims, dtype=unified_dtype_converter(input_val.dtype, Frameworks.TORCH)
+    )
     zero_tensor = network.add_constant(
         zero_tensor.shape, to_numpy(zero_tensor)
     ).get_output(0)
@@ -3670,7 +3677,7 @@ def acc_ops_new_ones(
     dtype_val = kwargs.get("dtype")
     if dtype_val is None:
         dtype_val = input_val.dtype
-        dtype_val = torch_dtype_from_trt(dtype_val)
+        dtype_val = unified_dtype_converter(dtype_val, Frameworks.TORCH)
 
     device_val = kwargs.get("device")
     assert (
@@ -3694,7 +3701,7 @@ def acc_ops_new_empty(
     dtype_val = kwargs.get("dtype")
     if dtype_val is None:
         dtype_val = input_val.dtype
-        dtype_val = torch_dtype_from_trt(dtype_val)
+        dtype_val = unified_dtype_converter(dtype_val, Frameworks.TORCH)
 
     device_val = kwargs.get("device")
     assert (
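
Most call sites in this file convert in the opposite direction, materializing a torch tensor whose dtype matches an existing TensorRT tensor (see acc_ops_any and acc_ops_cumsum above). A small sketch of that recurring pattern, assuming input_val is a TensorRT ITensor with a static shape (the helper name is made up):

import torch

from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter


def zeros_matching(input_val):
    # TensorRT dtype -> torch dtype, then build a matching all-zeros tensor
    torch_dtype = unified_dtype_converter(input_val.dtype, Frameworks.TORCH)
    return torch.zeros(tuple(input_val.shape), dtype=torch_dtype)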

py/torch_tensorrt/fx/converters/aten_ops_converters.py (-2)

@@ -18,8 +18,6 @@
 from torch.fx.immutable_collections import immutable_list
 from torch.fx.node import Argument, Target
 
-from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
-
 from .converter_utils import *  # noqa: F403
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch_tensorrt.fx.converters.impl import activation
