Commit 92523e3
feat: dynamic shape support for tan, sinh, cosh, asin and acos
1 parent 072c34e

File tree: 6 files changed, +257 −5 lines

  py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+5 −5)
  tests/py/dynamo/conversion/test_acos_aten.py (+51)
  tests/py/dynamo/conversion/test_asin_aten.py (+50)
  tests/py/dynamo/conversion/test_cosh_aten.py (+50)
  tests/py/dynamo/conversion/test_sinh_aten.py (+50)
  tests/py/dynamo/conversion/test_tan_aten.py (+51)
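
What the change enables: with supports_dynamic_shapes=True on these converters, graphs containing aten.tan, aten.sinh, aten.cosh, aten.asin, or aten.acos can be compiled against a shape range rather than a single fixed shape. The snippet below is a minimal usage sketch, not part of this commit; the module name TrigModel is made up for illustration, and it assumes the standard torch_tensorrt.compile dynamo path together with the torch_tensorrt.Input spec that the new tests use.

import torch
import torch_tensorrt


class TrigModel(torch.nn.Module):  # hypothetical module, for illustration only
    def forward(self, x):
        # tan/sinh/cosh accept any real input; asin/acos expect values in [-1, 1]
        return torch.tan(x) + torch.cosh(x)


model = TrigModel().eval().cuda()

# One Input spec with a min/opt/max range (mirroring the new tests) lets the
# compiled engine accept any input shape inside that range at runtime.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 1, 1),
        opt_shape=(1, 2, 3),
        max_shape=(3, 3, 3),
        dtype=torch.float32,
    )
]

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
out = trt_model(torch.randn(2, 2, 3, device="cuda"))  # any shape within the range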

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+5 −5)

@@ -1501,7 +1501,7 @@ def aten_ops_cos(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.tan.default)
+@dynamo_tensorrt_converter(torch.ops.aten.tan.default, supports_dynamic_shapes=True)
 def aten_ops_tan(
     ctx: ConversionContext,
     target: Target,
@@ -1518,7 +1518,7 @@ def aten_ops_tan(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.sinh.default)
+@dynamo_tensorrt_converter(torch.ops.aten.sinh.default, supports_dynamic_shapes=True)
 def aten_ops_sinh(
     ctx: ConversionContext,
     target: Target,
@@ -1535,7 +1535,7 @@ def aten_ops_sinh(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.cosh.default)
+@dynamo_tensorrt_converter(torch.ops.aten.cosh.default, supports_dynamic_shapes=True)
 def aten_ops_cosh(
     ctx: ConversionContext,
     target: Target,
@@ -1552,7 +1552,7 @@ def aten_ops_cosh(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.asin.default)
+@dynamo_tensorrt_converter(torch.ops.aten.asin.default, supports_dynamic_shapes=True)
 def aten_ops_asin(
     ctx: ConversionContext,
     target: Target,
@@ -1569,7 +1569,7 @@ def aten_ops_asin(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.acos.default)
+@dynamo_tensorrt_converter(torch.ops.aten.acos.default, supports_dynamic_shapes=True)
 def aten_ops_acos(
     ctx: ConversionContext,
     target: Target,
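
The only functional change in this file is the supports_dynamic_shapes=True flag on each registration; it tells the converter registry that these converters can handle inputs whose dimensions are not fixed at build time, so such nodes need not fall back to PyTorch. Below is a hedged sketch of the same pattern applied to another unary op: the import paths and the (ctx, target, args, kwargs, name) signature are assumptions based on how converters in this file are typically written (the hunks above truncate them), and the body is illustrative only, not this repository's actual implementation.

from typing import Dict, Sequence, Tuple, Union

import tensorrt as trt
import torch
from torch.fx.node import Argument, Target

# Assumed import paths; the real module locations may differ between versions.
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)


@dynamo_tensorrt_converter(torch.ops.aten.atan.default, supports_dynamic_shapes=True)
def aten_ops_atan_sketch(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
    # Elementwise unary ops need nothing shape-specific: dynamic dimensions
    # simply flow through the TensorRT layer, which is why declaring
    # dynamic-shape support is safe for them.
    layer = ctx.net.add_unary(args[0], trt.UnaryOperation.ATAN)
    layer.name = name
    return layer.get_output(0)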

tests/py/dynamo/conversion/test_acos_aten.py (+51)

@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -44,6 +45,56 @@ def forward(self, input):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "3d_dim_dtype_int32",
+                (3, 2, 1),
+                (3, 2, 3),
+                (3, 3, 4),
+                torch.int32,
+                torch.float32,
+            ),
+            (
+                "2d_dim_dtype_float16",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.float16,
+                torch.float16,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_acos_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class acos(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.acos.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            acos(),
+            input_specs,
+            output_dtypes=[output_type],
+        )
+
 
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_asin_aten.py (+50)

@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -44,6 +45,55 @@ def forward(self, input):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "3d_dim_dtype_int32",
+                (3, 2, 1),
+                (3, 2, 3),
+                (3, 3, 4),
+                torch.int32,
+                torch.float32,
+            ),
+            (
+                "2d_dim_dtype_float16",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.float16,
+                torch.float16,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_asin_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class asin(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.asin.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            asin(),
+            input_specs,
+            output_dtypes=[output_type],
+        )
+
 
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_cosh_aten.py (+50)

@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -44,6 +45,55 @@ def forward(self, input):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "3d_dim_dtype_int32",
+                (3, 2, 1),
+                (3, 2, 3),
+                (3, 3, 4),
+                torch.int32,
+                torch.float32,
+            ),
+            (
+                "2d_dim_dtype_float16",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.float16,
+                torch.float16,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_cosh_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class cosh(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.cosh.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            cosh(),
+            input_specs,
+            output_dtypes=[output_type],
+        )
+
 
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_sinh_aten.py (+50)

@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -44,6 +45,55 @@ def forward(self, input):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "3d_dim_dtype_int32",
+                (3, 2, 1),
+                (3, 2, 3),
+                (3, 3, 4),
+                torch.int32,
+                torch.float32,
+            ),
+            (
+                "2d_dim_dtype_float16",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.float16,
+                torch.float16,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_sinh_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class sinh(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.sinh.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            sinh(),
+            input_specs,
+            output_dtypes=[output_type],
+        )
+
 
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_tan_aten.py (+51)

@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -44,6 +45,56 @@ def forward(self, input):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "3d_dim_dtype_int32",
+                (3, 2, 1),
+                (3, 2, 3),
+                (3, 3, 4),
+                torch.int32,
+                torch.float32,
+            ),
+            (
+                "2d_dim_dtype_float16",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.float16,
+                torch.float16,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_tan_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class tan(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.tan.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            tan(),
+            input_specs,
+            output_dtypes=[output_type],
+        )
+
 
 if __name__ == "__main__":
     run_tests()
