1
+ #include < string>
2
+ #include " gtest/gtest.h"
3
+ #include " torch/csrc/jit/ir/irparser.h"
4
+ #include " tests/util/util.h"
5
+ #include " core/compiler.h"
6
+
7
+ TEST (Converters, ATenLSTMCellConvertsCorrectlyWithBias) {
8
+ const auto graph = R"IR(
9
+ graph(%0 : Tensor,
10
+ %1 : Tensor,
11
+ %2 : Tensor,
12
+ %3 : Tensor,
13
+ %4 : Tensor,
14
+ %5 : Tensor,
15
+ %6 : Tensor):
16
+ %7 : Tensor[] = prim::ListConstruct(%1, %2)
17
+ %8 : Tensor, %9 : Tensor = aten::lstm_cell(%0, %7, %3, %4, %5, %6)
18
+ return (%8))IR" ;
19
+
20
+ auto g = std::make_shared<torch::jit::Graph>();
21
+ torch::jit::parseIR (graph, &*g);
22
+
23
+ auto input = at::randn ({50 , 10 }, {at::kCUDA });
24
+ auto h0 = at::randn ({50 , 20 }, {at::kCUDA });
25
+ auto c0 = at::randn ({50 , 20 }, {at::kCUDA });
26
+ auto w_ih = at::randn ({4 *20 , 10 }, {at::kCUDA });
27
+ auto w_hh = at::randn ({4 *20 , 20 }, {at::kCUDA });
28
+ auto b_ih = at::randn ({4 *20 }, {at::kCUDA });
29
+ auto b_hh = at::randn ({4 *20 }, {at::kCUDA });
30
+
31
+ auto jit_input = at::clone (input);
32
+ auto jit_h0 = at::clone (h0);
33
+ auto jit_c0 = at::clone (c0);
34
+ auto jit_w_ih = at::clone (w_ih);
35
+ auto jit_w_hh = at::clone (w_hh);
36
+ auto jit_b_ih = at::clone (b_ih);
37
+ auto jit_b_hh = at::clone (b_hh);
38
+
39
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
40
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_input, jit_h0, jit_c0, jit_w_ih, jit_w_hh, jit_b_ih, jit_b_hh});
41
+
42
+ auto trt_input = at::clone (input);
43
+ auto trt_h0 = at::clone (h0);
44
+ auto trt_c0 = at::clone (c0);
45
+ auto trt_w_ih = at::clone (w_ih);
46
+ auto trt_w_hh = at::clone (w_hh);
47
+ auto trt_b_ih = at::clone (b_ih);
48
+ auto trt_b_hh = at::clone (b_hh);
49
+
50
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
51
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh, trt_b_ih, trt_b_hh});
52
+
53
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
54
+ }
55
+
56
+ TEST (Converters, ATenLSTMCellConvertsCorrectlyWithoutBias) {
57
+ const auto graph = R"IR(
58
+ graph(%0 : Tensor,
59
+ %1 : Tensor,
60
+ %2 : Tensor,
61
+ %3 : Tensor,
62
+ %4 : Tensor):
63
+ %5 : None = prim::Constant()
64
+ %6 : None = prim::Constant()
65
+ %7 : Tensor[] = prim::ListConstruct(%1, %2)
66
+ %8 : Tensor, %9 : Tensor = aten::lstm_cell(%0, %7, %3, %4, %5, %6)
67
+ return (%8))IR" ;
68
+
69
+ auto g = std::make_shared<torch::jit::Graph>();
70
+ torch::jit::parseIR (graph, &*g);
71
+
72
+ auto input = at::randn ({50 , 10 }, {at::kCUDA });
73
+ auto h0 = at::randn ({50 , 20 }, {at::kCUDA });
74
+ auto c0 = at::randn ({50 , 20 }, {at::kCUDA });
75
+ auto w_ih = at::randn ({4 *20 , 10 }, {at::kCUDA });
76
+ auto w_hh = at::randn ({4 *20 , 20 }, {at::kCUDA });
77
+
78
+ auto jit_input = at::clone (input);
79
+ auto jit_h0 = at::clone (h0);
80
+ auto jit_c0 = at::clone (c0);
81
+ auto jit_w_ih = at::clone (w_ih);
82
+ auto jit_w_hh = at::clone (w_hh);
83
+
84
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
85
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_input, jit_h0, jit_c0, jit_w_ih, jit_w_hh});
86
+
87
+ auto trt_input = at::clone (input);
88
+ auto trt_h0 = at::clone (h0);
89
+ auto trt_c0 = at::clone (c0);
90
+ auto trt_w_ih = at::clone (w_ih);
91
+ auto trt_w_hh = at::clone (w_hh);
92
+
93
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
94
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh});
95
+
96
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
97
+ }
0 commit comments