Skip to content

Commit d7c3164

Browse files
abhi-iyernarendasan
authored andcommitted
fix(): added test cases to explicitly check hidden/cell state outputs
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 8c61248 commit d7c3164

File tree

1 file changed

+94
-2
lines changed

1 file changed

+94
-2
lines changed

tests/core/converters/test_lstm_cell.cpp

+94-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include "tests/util/util.h"
55
#include "core/compiler.h"
66

7-
TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBias) {
7+
TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBiasCheckHidden) {
88
const auto graph = R"IR(
99
graph(%0 : Tensor,
1010
%1 : Tensor,
@@ -53,7 +53,56 @@ TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBias) {
5353
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
5454
}
5555

56-
TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBias) {
56+
TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBiasCheckCell) {
57+
const auto graph = R"IR(
58+
graph(%0 : Tensor,
59+
%1 : Tensor,
60+
%2 : Tensor,
61+
%3 : Tensor,
62+
%4 : Tensor,
63+
%5 : Tensor,
64+
%6 : Tensor):
65+
%7 : Tensor[] = prim::ListConstruct(%1, %2)
66+
%8 : Tensor, %9 : Tensor = aten::lstm_cell(%0, %7, %3, %4, %5, %6)
67+
return (%9))IR";
68+
69+
auto g = std::make_shared<torch::jit::Graph>();
70+
torch::jit::parseIR(graph, &*g);
71+
72+
auto input = at::randn({50, 10}, {at::kCUDA});
73+
auto h0 = at::randn({50, 20}, {at::kCUDA});
74+
auto c0 = at::randn({50, 20}, {at::kCUDA});
75+
auto w_ih = at::randn({4*20, 10}, {at::kCUDA});
76+
auto w_hh = at::randn({4*20, 20}, {at::kCUDA});
77+
auto b_ih = at::randn({4*20}, {at::kCUDA});
78+
auto b_hh = at::randn({4*20}, {at::kCUDA});
79+
80+
auto jit_input = at::clone(input);
81+
auto jit_h0 = at::clone(h0);
82+
auto jit_c0 = at::clone(c0);
83+
auto jit_w_ih = at::clone(w_ih);
84+
auto jit_w_hh = at::clone(w_hh);
85+
auto jit_b_ih = at::clone(b_ih);
86+
auto jit_b_hh = at::clone(b_hh);
87+
88+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
89+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_input, jit_h0, jit_c0, jit_w_ih, jit_w_hh, jit_b_ih, jit_b_hh});
90+
91+
auto trt_input = at::clone(input);
92+
auto trt_h0 = at::clone(h0);
93+
auto trt_c0 = at::clone(c0);
94+
auto trt_w_ih = at::clone(w_ih);
95+
auto trt_w_hh = at::clone(w_hh);
96+
auto trt_b_ih = at::clone(b_ih);
97+
auto trt_b_hh = at::clone(b_hh);
98+
99+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
100+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh, trt_b_ih, trt_b_hh});
101+
102+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
103+
}
104+
105+
TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBiasCheckHidden) {
57106
const auto graph = R"IR(
58107
graph(%0 : Tensor,
59108
%1 : Tensor,
@@ -93,5 +142,48 @@ TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBias) {
93142
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
94143
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh});
95144

145+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
146+
}
147+
148+
TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBiasCheckCell) {
149+
const auto graph = R"IR(
150+
graph(%0 : Tensor,
151+
%1 : Tensor,
152+
%2 : Tensor,
153+
%3 : Tensor,
154+
%4 : Tensor):
155+
%5 : None = prim::Constant()
156+
%6 : None = prim::Constant()
157+
%7 : Tensor[] = prim::ListConstruct(%1, %2)
158+
%8 : Tensor, %9 : Tensor = aten::lstm_cell(%0, %7, %3, %4, %5, %6)
159+
return (%9))IR";
160+
161+
auto g = std::make_shared<torch::jit::Graph>();
162+
torch::jit::parseIR(graph, &*g);
163+
164+
auto input = at::randn({50, 10}, {at::kCUDA});
165+
auto h0 = at::randn({50, 20}, {at::kCUDA});
166+
auto c0 = at::randn({50, 20}, {at::kCUDA});
167+
auto w_ih = at::randn({4*20, 10}, {at::kCUDA});
168+
auto w_hh = at::randn({4*20, 20}, {at::kCUDA});
169+
170+
auto jit_input = at::clone(input);
171+
auto jit_h0 = at::clone(h0);
172+
auto jit_c0 = at::clone(c0);
173+
auto jit_w_ih = at::clone(w_ih);
174+
auto jit_w_hh = at::clone(w_hh);
175+
176+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
177+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_input, jit_h0, jit_c0, jit_w_ih, jit_w_hh});
178+
179+
auto trt_input = at::clone(input);
180+
auto trt_h0 = at::clone(h0);
181+
auto trt_c0 = at::clone(c0);
182+
auto trt_w_ih = at::clone(w_ih);
183+
auto trt_w_hh = at::clone(w_hh);
184+
185+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
186+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh});
187+
96188
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
97189
}

0 commit comments

Comments
 (0)