-
Notifications
You must be signed in to change notification settings - Fork 363
/
Copy pathtest_acos_aten.py
100 lines (89 loc) · 2.51 KB
/
test_acos_aten.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
from .harness import DispatchTestCase
class TestAcosConverter(DispatchTestCase):
    """Converter tests for ``torch.ops.aten.acos.default``.

    Every test wraps the op in a minimal ``nn.Module`` and hands it to the
    ``DispatchTestCase`` harness: ``run_test`` receives concrete input
    tensors, ``run_test_with_dynamic_shape`` receives min/opt/max ``Input``
    specifications.
    """

    @parameterized.expand(
        [
            ((10,), torch.float),
            ((1, 20), torch.float),
            ((2, 3, 4), torch.float),
            ((2, 3, 4, 5), torch.float),
        ]
    )
    def test_acos_float(self, input_shape, dtype):
        # Float inputs of rank 1 through 4.
        class acos(nn.Module):
            def forward(self, input):
                return torch.ops.aten.acos.default(input)

        # NOTE(review): randn is unbounded, so part of the sample lies outside
        # acos's [-1, 1] domain and evaluates to NaN; presumably the harness
        # treats matching NaNs as equal -- confirm.
        sample = torch.randn(input_shape, dtype=dtype)
        self.run_test(acos(), [sample])

    @parameterized.expand(
        [
            ((10,), torch.int, 0, 5),
            ((1, 20), torch.int32, -10, 10),
            ((2, 3, 4), torch.int, -5, 5),
        ]
    )
    def test_acos_int(self, input_shape, dtype, low, high):
        # Integer inputs: the op is fed an int tensor directly.
        class acos(nn.Module):
            def forward(self, input):
                return torch.ops.aten.acos.default(input)

        # Values are drawn from [low, high).
        # NOTE(review): most of these integers are outside [-1, 1], so the
        # expected result is largely NaN -- confirm that is intended.
        sample = torch.randint(low, high, input_shape, dtype=dtype)
        self.run_test(acos(), [sample])

    @parameterized.expand(
        [
            (
                "3d_dim_dtype_int32",
                (3, 2, 1),
                (3, 2, 3),
                (3, 3, 4),
                torch.int32,
                torch.float32,
            ),
            (
                "2d_dim_dtype_float16",
                (1, 1),
                (2, 2),
                (4, 4),
                torch.float16,
                torch.float16,
            ),
            (
                "3d_dim_dtype_float",
                (1, 1, 1),
                (1, 2, 3),
                (3, 3, 3),
                torch.float,
                torch.float,
            ),
        ]
    )
    def test_acos_dynamic_shape(
        self, _, min_shape, opt_shape, max_shape, type, output_type
    ):
        # The first tuple element only labels the generated test case.
        # `type` (the input dtype) shadows the builtin; the name is kept
        # because it is part of the test's signature.
        class acos(nn.Module):
            def forward(self, input):
                return torch.ops.aten.acos.default(input)

        # One input whose shape may vary between min_shape and max_shape.
        dynamic_input = Input(
            min_shape=min_shape,
            opt_shape=opt_shape,
            max_shape=max_shape,
            dtype=type,
        )
        self.run_test_with_dynamic_shape(
            acos(),
            [dynamic_input],
            output_dtypes=[output_type],
        )
# Script entry point: hand control to the imported run_tests() helper when this
# file is executed directly; importing the module does not trigger it.
if __name__ == "__main__":
    run_tests()