Skip to content

Commit eaecfb2

Browse files
committed
chore: minor naming issues
1 parent b606306 commit eaecfb2

File tree

3 files changed: +13 additions, −25 deletions

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2610,7 +2610,6 @@ def aten_ops_pixel_unshuffle(
26102610
)
26112611

26122612

2613-
@dynamo_tensorrt_converter(torch.ops.aten.resize.default)
26142613
@dynamo_tensorrt_converter(torch.ops.aten.resize_.default)
26152614
@enforce_tensor_types(
26162615
{
@@ -2624,7 +2623,7 @@ def aten_ops_resize(
26242623
kwargs: Dict[str, Argument],
26252624
name: str,
26262625
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2627-
return impl.shuffle.resize_(
2626+
return impl.shuffle.resize(
26282627
ctx,
26292628
target,
26302629
SourceIR.ATEN,

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -132,17 +132,15 @@ def pixel_unshuffle(
132132
)
133133

134134

135-
def resize_(
135+
def resize(
136136
ctx: ConversionContext,
137137
target: Union[Target, str],
138138
source_ir: Optional[SourceIR],
139139
name: str,
140140
input: TRTTensor,
141141
sizes: Sequence[int],
142142
) -> TRTTensor:
143-
144143
input_np_dtype = unified_dtype_converter(input.dtype, Frameworks.NUMPY)
145-
146144
input_val = get_trt_tensor(ctx, input, f"{name}_input")
147145

148146
# Calculate the total number of elements for new and current shape
@@ -158,31 +156,34 @@ def resize_(
158156

159157
# Flatten input tensor to 1D for concatenation
160158
flatten_shape = flatten_dims(input_val, 0, -1)
161-
flattened_input = impl.shuffle.reshape(
159+
flattened_input = reshape(
162160
ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape
163161
)
164162

165163
# Concatenate the flattened input tensor and padding tensor
166-
concat_layer = ctx.net.add_concatenation([flattened_input, padding_tensor])
167-
concat_layer.axis = 0
168-
reshaped_tensor = concat_layer.get_output(0)
169-
164+
reshaped_tensor = impl.cat.cat(
165+
ctx,
166+
target,
167+
source_ir,
168+
f"{name}_cat",
169+
[flattened_input, padding_tensor],
170+
dim=0,
171+
)
170172
elif new_num_elements < current_num_elements:
171173
# Flatten input tensor to 1D for slicing
172174
flatten_shape = flatten_dims(input_val, 0, -1)
173-
flattened_input = impl.shuffle.reshape(
175+
flattened_input = reshape(
174176
ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape
175177
)
176178

177179
# Slice the flattened input tensor to the desired number of elements
178180
slice_layer = ctx.net.add_slice(flattened_input, [0], [new_num_elements], [1])
179181
reshaped_tensor = slice_layer.get_output(0)
180-
181182
else:
182183
reshaped_tensor = input_val
183184

184185
# Reshape the final output tensor to the target sizes
185-
resized_output = impl.shuffle.reshape(
186+
resized_output = reshape(
186187
ctx, target, source_ir, f"{name}_final_reshape", reshaped_tensor, sizes
187188
)
188189

tests/py/dynamo/conversion/test_resize_aten.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@ class TestResizeConverter(DispatchTestCase):
2020
)
2121
def test_resize_1d_input_float(self, target_shape):
2222
class Resize(torch.nn.Module):
23-
def __init__(self):
24-
super().__init__()
25-
2623
def forward(self, x):
2724
return torch.ops.aten.resize_.default(x, target_shape)
2825

@@ -46,9 +43,6 @@ def forward(self, x):
4643
)
4744
def test_resize_1d_input_int(self, target_shape):
4845
class Resize(torch.nn.Module):
49-
def __init__(self):
50-
super().__init__()
51-
5246
def forward(self, x):
5347
return torch.ops.aten.resize_.default(x, target_shape)
5448

@@ -73,9 +67,6 @@ def forward(self, x):
7367
)
7468
def test_resize_2d_input_float(self, target_shape):
7569
class Resize(torch.nn.Module):
76-
def __init__(self):
77-
super().__init__()
78-
7970
def forward(self, x):
8071
return torch.ops.aten.resize_.default(x, target_shape)
8172

@@ -100,9 +91,6 @@ def forward(self, x):
10091
)
10192
def test_resize_2d_input_int(self, target_shape):
10293
class Resize(torch.nn.Module):
103-
def __init__(self):
104-
super().__init__()
105-
10694
def forward(self, x):
10795
return torch.ops.aten.resize_.default(x, target_shape)
10896

0 commit comments

Comments (0)