Skip to content

Commit 732de3b

Browse files
committed
fix dim bugs and update tests
1 parent 754c3b5 commit 732de3b

File tree

7 files changed

+14
-7
lines changed

7 files changed

+14
-7
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ def aten_ops_amax(
668668
SourceIR.ATEN,
669669
name,
670670
args[0],
671-
args[1],
671+
args_bounds_check(args, 1, replacement=[]),
672672
args_bounds_check(args, 2, replacement=False),
673673
)
674674

py/torch_tensorrt/dynamo/conversion/impl/reduce.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def amax(
2727
):
2828
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
2929

30+
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
31+
dim = tuple(range(len(input_val.shape)))
32+
3033
layer = ctx.net.add_reduce(
3134
input_val,
3235
trt.ReduceOperation.MAX,
@@ -51,7 +54,7 @@ def sum(
5154
):
5255
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
5356

54-
if dim is None:
57+
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
5558
dim = tuple(range(len(input_val.shape)))
5659

5760
layer = ctx.net.add_reduce(
@@ -169,7 +172,7 @@ def mean(
169172
):
170173
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
171174

172-
if dim is None:
175+
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
173176
dim = tuple(range(len(input_val.shape)))
174177

175178
layer = ctx.net.add_reduce(

tests/py/dynamo/conversion/test_amax_aten.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def forward(self, x):
2929

3030
@parameterized.expand(
3131
[
32+
((1, 2, 4), [], True),
3233
((3, 2, 4), [1], True),
3334
((2, 1, 4, 5), [0, 3], True),
3435
((2, 3, 4, 5), [0, 1, 2, 3], False),
@@ -69,6 +70,7 @@ def forward(self, x):
6970

7071
@parameterized.expand(
7172
[
73+
((1, 2, 4), [], True, torch.int, 0, 5),
7274
((3, 2, 4), [1], True, torch.int, 0, 5),
7375
((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
7476
((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),

tests/py/dynamo/conversion/test_max_aten.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
class TestMaxConverter(DispatchTestCase):
1010
@parameterized.expand(
1111
[
12+
((1, 2),),
1213
((3, 2, 4),),
1314
((2, 3, 4, 5),),
14-
((2, 3, 4, 5),),
1515
((6, 7, 5, 4, 5),),
1616
]
1717
)

tests/py/dynamo/conversion/test_min_aten.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
class TestMinConverter(DispatchTestCase):
1010
@parameterized.expand(
1111
[
12+
((1, 2),),
1213
((3, 2, 4),),
1314
((2, 3, 4, 5),),
14-
((2, 3, 4, 5),),
1515
((6, 7, 5, 4, 5),),
1616
]
1717
)

tests/py/dynamo/conversion/test_prod_aten.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
class TestProdConverter(DispatchTestCase):
1010
@parameterized.expand(
1111
[
12+
((1, 2),),
1213
((3, 2, 4),),
1314
((2, 3, 4, 5),),
14-
((2, 3, 4, 5),),
1515
((6, 7, 5, 4, 5),),
1616
]
1717
)

tests/py/dynamo/conversion/test_sum_aten.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
class TestSumConverter(DispatchTestCase):
1010
@parameterized.expand(
1111
[
12+
((1, 2),),
1213
((3, 2, 4),),
1314
((2, 3, 4, 5),),
14-
((2, 3, 4, 5),),
1515
((6, 7, 5, 4, 5),),
1616
]
1717
)
@@ -49,6 +49,7 @@ def forward(self, x):
4949

5050
@parameterized.expand(
5151
[
52+
((1, 2, 4), [], True),
5253
((3, 2, 4), [1], True),
5354
((2, 1, 4, 5), None, True),
5455
((2, 3, 4, 5), [0, 1, 2, 3], False),
@@ -89,6 +90,7 @@ def forward(self, x):
8990

9091
@parameterized.expand(
9192
[
93+
((1, 2, 4), [], True, torch.int, 0, 5),
9294
((3, 2, 4), [1], True, torch.int, 0, 5),
9395
((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
9496
((2, 3, 4, 5), None, False, torch.int32, -5, 0),

0 commit comments

Comments
 (0)