Skip to content

Commit 8c61248

Browse files
abhi-iyernarendasan
authored andcommitted
feat(//core/conversion/converters): LSTMCell converter
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent a3e1093 commit 8c61248

File tree

3 files changed

+110
-5
lines changed

3 files changed

+110
-5
lines changed

core/conversion/converters/impl/lstm_cell.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@ nvinfer1::ITensor* add_bias(nvinfer1::ITensor* a, nvinfer1::ITensor* b, std::str
2828
auto shuffle = ctx->net->addShuffle(*b);
2929
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
3030
shuffle->setReshapeDimensions(util::toDimsPad(util::toVec(b_dim), a_dim.nbDims));
31+
3132
b = shuffle->getOutput(0);
3233
}
3334

35+
LOG_DEBUG(b_name << "'s shape: " << b->getDimensions());
36+
3437
auto add = ctx->net->addElementWise(*a, *b, nvinfer1::ElementWiseOperation::kSUM);
3538
TRTORCH_CHECK(add, "Unable to create ElementWise layer from node: " << *n);
3639

@@ -72,14 +75,14 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
7275
TRTORCH_CHECK(mm1, "Unable to create matrix multiplication node: " << *n);
7376
auto mm1_out = mm1->getOutput(0);
7477

75-
auto out1 = !args[4].IValue()->isNone() ? add_bias(mm1_out, args[4].ITensorOrFreeze(ctx), "b_ih", ctx, n) : mm1_out;
78+
auto out1 = (args[4].isIValue() && args[4].IValue()->isNone()) ? mm1_out : add_bias(mm1_out, args[4].ITensorOrFreeze(ctx), "b_ih", ctx, n);
7679

7780
// calculate second half of gates
7881
auto mm2 = ctx->net->addMatrixMultiply(*state[0], nvinfer1::MatrixOperation::kNONE, *w_hh, nvinfer1::MatrixOperation::kTRANSPOSE);
7982
TRTORCH_CHECK(mm2, "Unable to create matrix multiplication node: " << *n);
8083
auto mm2_out = mm2->getOutput(0);
8184

82-
auto out2 = !args[5].IValue()->isNone() ? add_bias(mm2_out, args[5].ITensorOrFreeze(ctx), "b_hh", ctx, n) : mm2_out;
85+
auto out2 = (args[5].isIValue() && args[5].IValue()->isNone()) ? mm2_out : add_bias(mm2_out, args[5].ITensorOrFreeze(ctx), "b_hh", ctx, n);
8386

8487
// gates
8588
auto add = ctx->net->addElementWise(*out1, *out2, nvinfer1::ElementWiseOperation::kSUM);
@@ -130,7 +133,7 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
130133
TRTORCH_CHECK(forget_cx, "Unable to create ElementWise layer from node: " << *n);
131134
auto in_cell = ctx->net->addElementWise(*ingate, *cellgate, nvinfer1::ElementWiseOperation::kPROD);
132135
TRTORCH_CHECK(in_cell, "Unable to create ElementWise layer from node: " << *n);
133-
auto cy = ctx->net->addElementWise(*forget_cx->getOutput(0), *in_cell->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
136+
auto cy = ctx->net->addElementWise(*forget_cx->getOutput(0), *in_cell->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
134137
TRTORCH_CHECK(cy, "Unable to create ElementWise layer from node: " << *n);
135138
auto cy_out = ctx->AssociateValueAndTensor(n->outputs()[1], cy->getOutput(0));
136139

@@ -143,7 +146,7 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
143146

144147
LOG_DEBUG("Output tensor [hy] shape: " << hy_out->getDimensions());
145148
LOG_DEBUG("Output tensor [cy] shape: " << cy_out->getDimensions());
146-
149+
147150
return true;
148151
}
149152
});

tests/core/converters/BUILD

+6-1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ converter_test(
6767
name = "test_stack"
6868
)
6969

70+
converter_test(
71+
name = "test_lstm_cell"
72+
)
73+
7074
test_suite(
7175
name = "test_converters",
7276
tests = [
@@ -83,6 +87,7 @@ test_suite(
8387
":test_unary",
8488
":test_interpolate",
8589
":test_select",
86-
":test_stack"
90+
":test_stack",
91+
":test_lstm_cell"
8792
]
8893
)
+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "torch/csrc/jit/ir/irparser.h"
4+
#include "tests/util/util.h"
5+
#include "core/compiler.h"
6+
7+
TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBias) {
8+
const auto graph = R"IR(
9+
graph(%0 : Tensor,
10+
%1 : Tensor,
11+
%2 : Tensor,
12+
%3 : Tensor,
13+
%4 : Tensor,
14+
%5 : Tensor,
15+
%6 : Tensor):
16+
%7 : Tensor[] = prim::ListConstruct(%1, %2)
17+
%8 : Tensor, %9 : Tensor = aten::lstm_cell(%0, %7, %3, %4, %5, %6)
18+
return (%8))IR";
19+
20+
auto g = std::make_shared<torch::jit::Graph>();
21+
torch::jit::parseIR(graph, &*g);
22+
23+
auto input = at::randn({50, 10}, {at::kCUDA});
24+
auto h0 = at::randn({50, 20}, {at::kCUDA});
25+
auto c0 = at::randn({50, 20}, {at::kCUDA});
26+
auto w_ih = at::randn({4*20, 10}, {at::kCUDA});
27+
auto w_hh = at::randn({4*20, 20}, {at::kCUDA});
28+
auto b_ih = at::randn({4*20}, {at::kCUDA});
29+
auto b_hh = at::randn({4*20}, {at::kCUDA});
30+
31+
auto jit_input = at::clone(input);
32+
auto jit_h0 = at::clone(h0);
33+
auto jit_c0 = at::clone(c0);
34+
auto jit_w_ih = at::clone(w_ih);
35+
auto jit_w_hh = at::clone(w_hh);
36+
auto jit_b_ih = at::clone(b_ih);
37+
auto jit_b_hh = at::clone(b_hh);
38+
39+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
40+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_input, jit_h0, jit_c0, jit_w_ih, jit_w_hh, jit_b_ih, jit_b_hh});
41+
42+
auto trt_input = at::clone(input);
43+
auto trt_h0 = at::clone(h0);
44+
auto trt_c0 = at::clone(c0);
45+
auto trt_w_ih = at::clone(w_ih);
46+
auto trt_w_hh = at::clone(w_hh);
47+
auto trt_b_ih = at::clone(b_ih);
48+
auto trt_b_hh = at::clone(b_hh);
49+
50+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
51+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh, trt_b_ih, trt_b_hh});
52+
53+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
54+
}
55+
56+
TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBias) {
57+
const auto graph = R"IR(
58+
graph(%0 : Tensor,
59+
%1 : Tensor,
60+
%2 : Tensor,
61+
%3 : Tensor,
62+
%4 : Tensor):
63+
%5 : None = prim::Constant()
64+
%6 : None = prim::Constant()
65+
%7 : Tensor[] = prim::ListConstruct(%1, %2)
66+
%8 : Tensor, %9 : Tensor = aten::lstm_cell(%0, %7, %3, %4, %5, %6)
67+
return (%8))IR";
68+
69+
auto g = std::make_shared<torch::jit::Graph>();
70+
torch::jit::parseIR(graph, &*g);
71+
72+
auto input = at::randn({50, 10}, {at::kCUDA});
73+
auto h0 = at::randn({50, 20}, {at::kCUDA});
74+
auto c0 = at::randn({50, 20}, {at::kCUDA});
75+
auto w_ih = at::randn({4*20, 10}, {at::kCUDA});
76+
auto w_hh = at::randn({4*20, 20}, {at::kCUDA});
77+
78+
auto jit_input = at::clone(input);
79+
auto jit_h0 = at::clone(h0);
80+
auto jit_c0 = at::clone(c0);
81+
auto jit_w_ih = at::clone(w_ih);
82+
auto jit_w_hh = at::clone(w_hh);
83+
84+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
85+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_input, jit_h0, jit_c0, jit_w_ih, jit_w_hh});
86+
87+
auto trt_input = at::clone(input);
88+
auto trt_h0 = at::clone(h0);
89+
auto trt_c0 = at::clone(c0);
90+
auto trt_w_ih = at::clone(w_ih);
91+
auto trt_w_hh = at::clone(w_hh);
92+
93+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
94+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh});
95+
96+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
97+
}

0 commit comments

Comments
 (0)