
Commit 950b791

feat: support aten.resize_ converter (#2874)
1 parent cbdad29 commit 950b791

File tree

5 files changed: +337 -7 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+23)

@@ -2756,6 +2756,29 @@ def aten_ops_pixel_unshuffle(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.resize_.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_resize(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shuffle.resize(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        sizes=args[1],
+    )
+
+
 @enforce_tensor_types({0: (TRTTensor,)})
 @dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
 def aten_ops_argmax(
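Once registered, graphs carrying torch.ops.aten.resize_.default can be lowered along the Dynamo path instead of falling back to PyTorch. A minimal sketch of a module that would exercise the new converter; the module name, shapes, and the compile call are illustrative, not taken from this commit:

import torch
import torch_tensorrt

class ResizeModule(torch.nn.Module):
    def forward(self, x):
        # aten.resize_ changes the element count; the converter above zero-fills
        # newly added elements (eager PyTorch leaves them uninitialized, so exact
        # agreement holds for the shrinking and equal-size cases)
        return torch.ops.aten.resize_(x, [3, 4])

# Hypothetical usage (requires a CUDA device and TensorRT installed):
# inputs = [torch.randn(2, 3).cuda()]
# exp_program = torch.export.export(ResizeModule().cuda().eval(), tuple(inputs))
# trt_mod = torch_tensorrt.dynamo.compile(exp_program, inputs=inputs)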

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py (+63 -1)

@@ -1,15 +1,19 @@
 from typing import Optional, Sequence, Union
 
+import numpy as np
+import tensorrt as trt
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt import _enums
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     SourceIR,
     cast_trt_tensor,
+    flatten_dims,
     get_trt_tensor,
+    set_layer_name,
 )
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.dynamo.utils import Frameworks, unified_dtype_converter
 from torch_tensorrt.fx.types import TRTTensor
 
@@ -131,3 +135,61 @@ def pixel_unshuffle(
         permuted_tensor,
         shape[:-3] + (out_channels, out_height, out_width),
     )
+
+
+def resize(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    sizes: Sequence[int],
+) -> TRTTensor:
+    input_np_dtype = unified_dtype_converter(input.dtype, Frameworks.NUMPY)
+    input_val = get_trt_tensor(ctx, input, f"{name}_input")
+
+    # Calculate the total number of elements for new and current shape
+    new_num_elements = np.prod(sizes)
+    current_num_elements = np.prod(input_val.shape)
+
+    if new_num_elements > current_num_elements:
+        # Create a padding tensor with the required size and initialize new elements with zeros
+        padding_size = new_num_elements - current_num_elements
+        padding_tensor = ctx.net.add_constant(
+            (padding_size,), trt.Weights(np.zeros(padding_size, dtype=input_np_dtype))
+        ).get_output(0)
+
+        # Flatten input tensor to 1D for concatenation
+        flatten_shape = flatten_dims(input_val, 0, -1)
+        flattened_input = reshape(
+            ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape
+        )
+
+        # Concatenate the flattened input tensor and padding tensor
+        reshaped_tensor = impl.cat.cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_cat",
+            [flattened_input, padding_tensor],
+            dim=0,
+        )
+    elif new_num_elements < current_num_elements:
+        # Flatten input tensor to 1D for slicing
+        flatten_shape = flatten_dims(input_val, 0, -1)
+        flattened_input = reshape(
+            ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape
+        )
+
+        # Slice the flattened input tensor to the desired number of elements
+        slice_layer = ctx.net.add_slice(flattened_input, [0], [new_num_elements], [1])
+        reshaped_tensor = slice_layer.get_output(0)
+    else:
+        reshaped_tensor = input_val
+
+    # Reshape the final output tensor to the target sizes
+    resized_output = reshape(
+        ctx, target, source_ir, f"{name}_final_reshape", reshaped_tensor, sizes
+    )
+
+    return resized_output
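The TensorRT layers above implement resize_ as flatten, zero-pad or truncate to the requested element count, then reshape. For comparison only, the same semantics expressed in plain PyTorch; this reference sketch is not part of the diff:

import torch

def resize_reference(x: torch.Tensor, sizes) -> torch.Tensor:
    # Flatten, then zero-pad or truncate so the element count matches `sizes`
    flat = x.reshape(-1)
    new_numel = 1
    for s in sizes:
        new_numel *= s
    if new_numel > flat.numel():
        padding = torch.zeros(new_numel - flat.numel(), dtype=x.dtype, device=x.device)
        flat = torch.cat([flat, padding], dim=0)
    elif new_numel < flat.numel():
        flat = flat[:new_numel]
    # Reshape the padded/truncated buffer to the target sizes
    return flat.reshape(sizes)

# e.g. resize_reference(torch.arange(6.0).reshape(2, 3), (3, 4)) yields a 3x4 tensor
# whose last six elements are zeros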

py/torch_tensorrt/dynamo/utils.py (+85)

@@ -2,8 +2,11 @@
 
 import logging
 from dataclasses import fields, replace
+from enum import Enum
 from typing import Any, Callable, Dict, Optional, Sequence, Union
 
+import numpy as np
+import tensorrt as trt
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
@@ -13,12 +16,63 @@
 
 from packaging import version
 
+from .types import TRTDataType
+
 logger = logging.getLogger(__name__)
 
 COSINE_THRESHOLD = 0.99
 DYNAMIC_DIM = -1
 
 
+class Frameworks(Enum):
+    NUMPY = "numpy"
+    TORCH = "torch"
+    TRT = "trt"
+
+
+DataTypeEquivalence: Dict[
+    TRTDataType, Dict[Frameworks, Union[TRTDataType, np.dtype, torch.dtype]]
+] = {
+    trt.int8: {
+        Frameworks.NUMPY: np.int8,
+        Frameworks.TORCH: torch.int8,
+        Frameworks.TRT: trt.int8,
+    },
+    trt.int32: {
+        Frameworks.NUMPY: np.int32,
+        Frameworks.TORCH: torch.int32,
+        Frameworks.TRT: trt.int32,
+    },
+    trt.int64: {
+        Frameworks.NUMPY: np.int64,
+        Frameworks.TORCH: torch.int64,
+        Frameworks.TRT: trt.int64,
+    },
+    trt.float16: {
+        Frameworks.NUMPY: np.float16,
+        Frameworks.TORCH: torch.float16,
+        Frameworks.TRT: trt.float16,
+    },
+    trt.float32: {
+        Frameworks.NUMPY: np.float32,
+        Frameworks.TORCH: torch.float32,
+        Frameworks.TRT: trt.float32,
+    },
+    trt.bool: {
+        Frameworks.NUMPY: bool,
+        Frameworks.TORCH: torch.bool,
+        Frameworks.TRT: trt.bool,
+    },
+}
+
+if trt.__version__ >= "7.0":
+    DataTypeEquivalence[trt.bool] = {
+        Frameworks.NUMPY: np.bool_,
+        Frameworks.TORCH: torch.bool,
+        Frameworks.TRT: trt.bool,
+    }
+
+
 def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool:
     """Parses a user-provided input argument regarding Python runtime
 
@@ -317,3 +371,34 @@ def function_wrapper(*args: Any, **kwargs: Any) -> Any:
         return function_wrapper
 
     return nested_decorator
+
+
+def unified_dtype_converter(
+    dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks
+) -> Union[np.dtype, torch.dtype, TRTDataType]:
+    """
+    Convert TensorRT, Numpy, or Torch data types to any other of those data types.
+
+    Args:
+        dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type.
+        to (Frameworks): The framework to convert the data type to.
+
+    Returns:
+        The equivalent data type in the requested framework.
+    """
+    assert to in Frameworks, f"Expected valid Framework for translation, got {to}"
+    trt_major_version = int(trt.__version__.split(".")[0])
+    if dtype in (np.int8, torch.int8, trt.int8):
+        return DataTypeEquivalence[trt.int8][to]
+    elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool):
+        return DataTypeEquivalence[trt.bool][to]
+    elif dtype in (np.int32, torch.int32, trt.int32):
+        return DataTypeEquivalence[trt.int32][to]
+    elif dtype in (np.int64, torch.int64, trt.int64):
+        return DataTypeEquivalence[trt.int64][to]
+    elif dtype in (np.float16, torch.float16, trt.float16):
+        return DataTypeEquivalence[trt.float16][to]
+    elif dtype in (np.float32, torch.float32, trt.float32):
+        return DataTypeEquivalence[trt.float32][to]
+    else:
+        raise TypeError("%s is not a supported dtype" % dtype)
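A quick illustration of the new helper, assuming the import path introduced in this diff; the expected values follow directly from the DataTypeEquivalence table:

import numpy as np
import tensorrt as trt
import torch
from torch_tensorrt.dynamo.utils import Frameworks, unified_dtype_converter

# TensorRT -> NumPy
assert unified_dtype_converter(trt.float16, Frameworks.NUMPY) is np.float16
# Torch -> TensorRT
assert unified_dtype_converter(torch.int32, Frameworks.TRT) == trt.int32
# NumPy -> Torch
assert unified_dtype_converter(np.float32, Frameworks.TORCH) is torch.float32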

tests/py/dynamo/conversion/harness.py (+14 -6)

@@ -157,12 +157,20 @@ def run_test_custom_compare_results(
         res_trt = trt_mod(*cuda_inputs).cpu()
         res_cpu = mod(*cuda_inputs).cpu()
         assert len(res_trt) == len(res_cpu)
-        for output_trt, output_cpu, comparator in zip(
-            res_trt, res_cpu, comparators
-        ):
-            comp_func = comparator[0]
-            args = comparator[1]
-            self.assertTrue(comp_func(output_trt, output_cpu, *args))
+        comparator = comparators
+
+        if len(cuda_inputs) == 1:
+            for comparator in comparators:
+                comp_func = comparator[0]
+                args = comparator[1]
+                self.assertTrue(comp_func(res_trt, res_cpu, *args))
+        else:
+            for output_trt, output_cpu, comparator in zip(
+                res_trt, res_cpu, comparators
+            ):
+                comp_func = comparator[0]
+                args = comparator[1]
+                self.assertTrue(comp_func(output_trt, output_cpu, *args))
 
     def run_test_with_error(self, mod, inputs, interpreter, expect_error):
         with self.assertRaises(expect_error):
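With this change, a single-input test applies each comparator to the whole TRT and CPU results instead of zipping comparators against individual outputs. A sketch of the comparator structure the harness expects; the surrounding test class and call site are assumed, not shown in the diff:

import torch

# Each comparator is a (function, extra_args) pair; for single-input tests the
# function receives the full TRT and CPU results plus the extra args.
comparators = [
    (
        lambda trt_out, cpu_out, rtol, atol: torch.allclose(
            trt_out, cpu_out, rtol=rtol, atol=atol
        ),
        [1e-3, 1e-3],
    ),
]
# Inside the harness's test case class (hypothetical call):
# self.run_test_custom_compare_results(mod, inputs, expected_ops, comparators)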
