
Commit 406d860

feat: Add converter files for reflection pad 1d and 2d

Signed-off-by: Dheeraj Peri <[email protected]>

1 parent 8e89578 commit 406d860
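For context (not part of the diff): reflection padding mirrors the interior of a tensor across its edges without repeating the edge element itself, so a pad of p on one side copies the elements at indices p, p-1, ..., 1. A minimal libtorch sketch, assuming a standalone program linked against libtorch, of the reference behavior the new converters must reproduce:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Input values 0, 1, 2, 3 along the last axis, padded with (left=2, right=1).
  auto x = at::arange(4, at::kFloat).reshape({1, 1, 4});
  auto y = at::reflection_pad1d(x, {2, 1});
  // y holds 2, 1, 0, 1, 2, 3, 2 along the last axis: the left border reflects
  // indices 2 and 1, the right border reflects index 2 (the element just
  // before the last one).
  std::cout << y << std::endl;
  return 0;
}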

2 files changed: +201 −0 lines
@@ -0,0 +1,145 @@
#include <ATen/ATen.h>
#include <vector>
#include "NvInfer.h"
#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

namespace torch_tensorrt {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

auto reflection_padXd TORCHTRT_UNUSED =
    RegisterNodeConversionPatterns()
        .pattern({"aten::reflection_pad2d(Tensor self, int[4] padding) -> (Tensor)",
                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                    auto in = args[0].ITensor();
                    auto inDims = in->getDimensions();
                    int64_t inRank = inDims.nbDims;
                    auto padding = args[1].unwrapToIntList().vec();
                    if (padding.size() == 1) {
                      for (int64_t i = 0; i < 3; i++)
                        padding.push_back(padding[0]);
                    }
                    if (inRank == 4) {
                      TORCHTRT_CHECK(padding.size() == 4, "4D tensors expect 4 values for padding");
                    } else {
                      TORCHTRT_THROW_ERROR("Only 4D padding is supported for now");
                    }

                    std::vector<nvinfer1::ITensor*> tensors_vec;
                    // 2d padding: (padding_left, padding_right, padding_top, padding_bottom)

                    for (int64_t i = 0; i < int(padding.size() / 2); i++) {
                      int64_t axis = inRank - (i + 1); // axis = {inRank - 1, inRank - 2}
                      int64_t padding_index = i * 2;

                      if (padding[padding_index] > 0) { // left/top padding value
                        tensors_vec.clear();

                        for (int i = 0; i < padding[padding_index]; i++) {
                          at::Tensor left_indices = torch::tensor({padding[padding_index] - i}, torch::kInt32);
                          auto indicesTensor = tensor_to_const(ctx, left_indices);
                          auto left_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
                          auto left_gather_out = left_gather_layer->getOutput(0);
                          tensors_vec.push_back(left_gather_out);
                        }
                        tensors_vec.push_back(in);
                        auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
                        concat_layer->setAxis(axis);
                        in = concat_layer->getOutput(0);
                        inDims = in->getDimensions();
                      }

                      if (padding[padding_index + 1] > 0) { // right/bottom padding value
                        tensors_vec.clear();
                        tensors_vec.push_back(in);

                        for (int i = 0; i < padding[padding_index + 1]; i++) {
                          nvinfer1::ITensor* indicesTensor = NULL;
                          auto indices = torch::tensor({inDims.d[axis] - 1 - (i + 1)}, torch::kInt32);
                          indicesTensor = tensor_to_const(ctx, indices);
                          auto right_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
                          auto right_gather_out = right_gather_layer->getOutput(0);
                          tensors_vec.push_back(right_gather_out);
                        }

                        auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
                        concat_layer->setAxis(axis);
                        in = concat_layer->getOutput(0);
                        inDims = in->getDimensions();
                      }
                    }
                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());

                    return true;
                  }})
        .pattern({"aten::reflection_pad1d(Tensor self, int[2] padding) -> (Tensor)",
                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                    auto in = args[0].ITensor();
                    auto inDims = in->getDimensions();
                    int64_t inRank = inDims.nbDims;
                    auto padding = args[1].unwrapToIntList().vec();
                    if (padding.size() == 1) {
                      for (int64_t i = 0; i < 1; i++)
                        padding.push_back(padding[0]);
                    }

                    std::vector<nvinfer1::ITensor*> tensors_vec;
                    // 1d padding: (padding_left, padding_right)

                    int64_t axis = inRank - 1;
                    int64_t padding_index = 0;

                    if (padding[padding_index] > 0) { // left padding value
                      tensors_vec.clear();

                      for (int i = 0; i < padding[padding_index]; i++) {
                        at::Tensor left_indices = torch::tensor({padding[padding_index] - i}, torch::kInt32);
                        auto indicesTensor = tensor_to_const(ctx, left_indices);
                        auto left_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
                        auto left_gather_out = left_gather_layer->getOutput(0);
                        tensors_vec.push_back(left_gather_out);
                      }
                      tensors_vec.push_back(in);
                      auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
                      concat_layer->setAxis(axis);
                      in = concat_layer->getOutput(0);
                      inDims = in->getDimensions();
                    }

                    if (padding[padding_index + 1] > 0) { // right padding value
                      tensors_vec.clear();
                      tensors_vec.push_back(in);

                      for (int i = 0; i < padding[padding_index + 1]; i++) {
                        nvinfer1::ITensor* indicesTensor = NULL;
                        auto indices = torch::tensor({inDims.d[axis] - 1 - (i + 1)}, torch::kInt32);
                        indicesTensor = tensor_to_const(ctx, indices);
                        auto right_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
                        auto right_gather_out = right_gather_layer->getOutput(0);
                        tensors_vec.push_back(right_gather_out);
                      }

                      auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
                      concat_layer->setAxis(axis);
                      in = concat_layer->getOutput(0);
                      inDims = in->getDimensions();
                    }

                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());

                    return true;
                  }});

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace torch_tensorrt
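A note on the approach (my reading of the converter above, not a statement from the commit): rather than using a dedicated padding layer, the converter lowers reflection padding to TensorRT IGatherLayer and IConcatenationLayer ops, gathering one mirrored slice per padded element along each axis and concatenating those slices around the running tensor. A hypothetical standalone helper, not in this commit, that mirrors the same index arithmetic:

#include <cstdint>
#include <vector>

// Hypothetical helper (illustration only): the source indices that the
// gather/concat sequence above effectively assembles along one axis of
// length dim for a reflection pad of (pad_left, pad_right).
std::vector<int64_t> reflection_indices(int64_t dim, int64_t pad_left, int64_t pad_right) {
  std::vector<int64_t> idx;
  for (int64_t i = 0; i < pad_left; i++) {
    idx.push_back(pad_left - i); // left/top border: pad_left, ..., 1
  }
  for (int64_t i = 0; i < dim; i++) {
    idx.push_back(i); // the original slice, unchanged
  }
  for (int64_t i = 0; i < pad_right; i++) {
    idx.push_back(dim - 2 - i); // right/bottom border: dim - 2, ..., dim - 1 - pad_right
  }
  return idx;
}

// Example: reflection_indices(4, 2, 1) == {2, 1, 0, 1, 2, 3, 2}, matching the
// at::reflection_pad1d sketch earlier. The converter computes the right-hand
// indices against the already left-padded tensor, which yields the same values.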
@@ -0,0 +1,56 @@
#include <iostream>
#include <string>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ATenReflection_pad2dTensorConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%0 : Tensor):
      %1 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=1]()
      %3 : int = prim::Constant[value=2]()
      %4 : int = prim::Constant[value=0]()
      %5 : int[] = prim::ListConstruct(%1, %2, %3, %4)
      %6 : Tensor = aten::reflection_pad2d(%0, %5)
      return (%6))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in1 = at::randint(1, 10, {1, 3, 5, 5}, {at::kCUDA});

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenReflection_pad1dTensorConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%0 : Tensor):
      %1 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %3 : int[] = prim::ListConstruct(%1, %2)
      %4 : Tensor = aten::reflection_pad1d(%0, %3)
      return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in1 = at::randint(1, 10, {1, 2, 4}, {at::kCUDA});

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
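For orientation (derived from the padding values in the graphs above, not an extra assertion in the tests): the 2d test pads a {1, 3, 5, 5} input with (left=1, right=1, top=2, bottom=0), so both the TorchScript and TensorRT paths should return a {1, 3, 5 + 2 + 0, 5 + 1 + 1} = {1, 3, 7, 7} tensor; the 1d test pads {1, 2, 4} with (left=1, right=2), giving {1, 2, 7}.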
