4
4
#include " tests/util/util.h"
5
5
#include " core/compiler.h"
6
6
7
- TEST (Converters, ATenLSTMCellConvertsCorrectlyWithBias ) {
7
+ TEST (Converters, ATenLSTMCellConvertsCorrectlyWithBiasCheckHidden ) {
8
8
const auto graph = R"IR(
9
9
graph(%0 : Tensor,
10
10
%1 : Tensor,
@@ -53,7 +53,56 @@ TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBias) {
53
53
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
54
54
}
55
55
56
- TEST (Converters, ATenLSTMCellConvertsCorrectlyWithoutBias) {
56
+ TEST (Converters, ATenLSTMCellConvertsCorrectlyWithBiasCheckCell) {
57
+ const auto graph = R"IR(
58
+ graph(%0 : Tensor,
59
+ %1 : Tensor,
60
+ %2 : Tensor,
61
+ %3 : Tensor,
62
+ %4 : Tensor,
63
+ %5 : Tensor,
64
+ %6 : Tensor):
65
+ %7 : Tensor[] = prim::ListConstruct(%1, %2)
66
+ %8 : Tensor, %9 : Tensor = aten::lstm_cell(%0, %7, %3, %4, %5, %6)
67
+ return (%9))IR" ;
68
+
69
+ auto g = std::make_shared<torch::jit::Graph>();
70
+ torch::jit::parseIR (graph, &*g);
71
+
72
+ auto input = at::randn ({50 , 10 }, {at::kCUDA });
73
+ auto h0 = at::randn ({50 , 20 }, {at::kCUDA });
74
+ auto c0 = at::randn ({50 , 20 }, {at::kCUDA });
75
+ auto w_ih = at::randn ({4 *20 , 10 }, {at::kCUDA });
76
+ auto w_hh = at::randn ({4 *20 , 20 }, {at::kCUDA });
77
+ auto b_ih = at::randn ({4 *20 }, {at::kCUDA });
78
+ auto b_hh = at::randn ({4 *20 }, {at::kCUDA });
79
+
80
+ auto jit_input = at::clone (input);
81
+ auto jit_h0 = at::clone (h0);
82
+ auto jit_c0 = at::clone (c0);
83
+ auto jit_w_ih = at::clone (w_ih);
84
+ auto jit_w_hh = at::clone (w_hh);
85
+ auto jit_b_ih = at::clone (b_ih);
86
+ auto jit_b_hh = at::clone (b_hh);
87
+
88
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
89
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_input, jit_h0, jit_c0, jit_w_ih, jit_w_hh, jit_b_ih, jit_b_hh});
90
+
91
+ auto trt_input = at::clone (input);
92
+ auto trt_h0 = at::clone (h0);
93
+ auto trt_c0 = at::clone (c0);
94
+ auto trt_w_ih = at::clone (w_ih);
95
+ auto trt_w_hh = at::clone (w_hh);
96
+ auto trt_b_ih = at::clone (b_ih);
97
+ auto trt_b_hh = at::clone (b_hh);
98
+
99
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
100
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh, trt_b_ih, trt_b_hh});
101
+
102
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
103
+ }
104
+
105
+ TEST (Converters, ATenLSTMCellConvertsCorrectlyWithoutBiasCheckHidden) {
57
106
const auto graph = R"IR(
58
107
graph(%0 : Tensor,
59
108
%1 : Tensor,
@@ -93,5 +142,48 @@ TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBias) {
93
142
params = trtorch::core::conversion::get_named_params (g->inputs (), {});
94
143
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh});
95
144
145
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
146
+ }
147
+
148
+ TEST (Converters, ATenLSTMCellConvertsCorrectlyWithoutBiasCheckCell) {
149
+ const auto graph = R"IR(
150
+ graph(%0 : Tensor,
151
+ %1 : Tensor,
152
+ %2 : Tensor,
153
+ %3 : Tensor,
154
+ %4 : Tensor):
155
+ %5 : None = prim::Constant()
156
+ %6 : None = prim::Constant()
157
+ %7 : Tensor[] = prim::ListConstruct(%1, %2)
158
+ %8 : Tensor, %9 : Tensor = aten::lstm_cell(%0, %7, %3, %4, %5, %6)
159
+ return (%9))IR" ;
160
+
161
+ auto g = std::make_shared<torch::jit::Graph>();
162
+ torch::jit::parseIR (graph, &*g);
163
+
164
+ auto input = at::randn ({50 , 10 }, {at::kCUDA });
165
+ auto h0 = at::randn ({50 , 20 }, {at::kCUDA });
166
+ auto c0 = at::randn ({50 , 20 }, {at::kCUDA });
167
+ auto w_ih = at::randn ({4 *20 , 10 }, {at::kCUDA });
168
+ auto w_hh = at::randn ({4 *20 , 20 }, {at::kCUDA });
169
+
170
+ auto jit_input = at::clone (input);
171
+ auto jit_h0 = at::clone (h0);
172
+ auto jit_c0 = at::clone (c0);
173
+ auto jit_w_ih = at::clone (w_ih);
174
+ auto jit_w_hh = at::clone (w_hh);
175
+
176
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
177
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_input, jit_h0, jit_c0, jit_w_ih, jit_w_hh});
178
+
179
+ auto trt_input = at::clone (input);
180
+ auto trt_h0 = at::clone (h0);
181
+ auto trt_c0 = at::clone (c0);
182
+ auto trt_w_ih = at::clone (w_ih);
183
+ auto trt_w_hh = at::clone (w_hh);
184
+
185
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
186
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh});
187
+
96
188
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
97
189
}
0 commit comments