Skip to content

Commit 8eb2eab

Browse files
authored
fix: Resolve compilation bug for empty tensors in aten::select (#1623)
1 parent 3e422f5 commit 8eb2eab

File tree

6 files changed

+100
-5
lines changed

6 files changed

+100
-5
lines changed

core/conversion/converters/impl/select.cpp

+20-2
Original file line number | Diff line number | Diff line change
@@ -149,8 +149,26 @@ auto select_registrations TORCHTRT_UNUSED =
149149
// IShuffleLayer removes redundant dimensions
150150
auto shuffle_layer = ctx->net->addShuffle(*out);
151151
TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
152-
shuffle_layer->setReshapeDimensions(
153-
util::squeezeDims(out->getDimensions(), dim, !ctx->input_is_dynamic));
152+
153+
auto num_zero_dimensions =
154+
util::validateInputDimsForShuffle(out->getDimensions(), ctx->input_is_dynamic);
155+
TORCHTRT_CHECK(
156+
num_zero_dimensions >= 0,
157+
"Detected multiple zero dimensions and dynamic shape in aten::select, "
158+
<< "which is not currently supported in TensorRT");
159+
160+
// If the input is not dynamic, and the tensor is empty (has some dimension 0)
161+
// Then 0 is no longer a placeholder for inherited dimensions
162+
if (!ctx->input_is_dynamic && (num_zero_dimensions > 0)) {
163+
LOG_DEBUG("Setting zero as a true dimension (not placeholder) in aten::select");
164+
shuffle_layer->setZeroIsPlaceholder(false);
165+
}
166+
167+
shuffle_layer->setReshapeDimensions(util::squeezeDims(
168+
out->getDimensions(),
169+
dim,
170+
ctx->input_is_dynamic,
171+
ctx->input_is_dynamic && (num_zero_dimensions > 0)));
154172
shuffle_layer->setName(util::node_info(n).c_str());
155173
out = shuffle_layer->getOutput(0);
156174
}

core/util/trt_util.cpp

+37-2
Original file line number | Diff line number | Diff line change
@@ -180,15 +180,50 @@ nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val, bool use
180180
return dims;
181181
}
182182

183-
nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros) {
183+
// Counts zero-valued extents in d and validates that the shape is usable by a
// TensorRT IShuffleLayer. Returns the number of zero dimensions found, or -1
// when the combination is invalid (dynamic input shape AND more than one zero
// dimension): Shuffle layers use 0 as a placeholder meaning "inherit this
// dimension from the input tensor", so multiple zeros under dynamic shape
// overload that placeholder and cannot be expressed.
int validateInputDimsForShuffle(const nvinfer1::Dims& d, bool input_is_dynamic) {
  int zero_count = 0;

  // Tally every dimension whose extent is exactly 0.
  for (int idx = 0; idx < d.nbDims; idx++) {
    zero_count += (d.d[idx] == 0) ? 1 : 0;
  }

  // Invalid: dynamic shape with >1 zero dims is ambiguous to the Shuffle layer.
  if (input_is_dynamic && zero_count > 1) {
    return -1;
  }

  return zero_count;
}
200+
201+
nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros, bool swap_existing_zeros) {
184202
// acceptable range for pos is [0, d.nbDims]
185203
TORCHTRT_ASSERT(pos >= 0 && pos <= d.nbDims, "ERROR: Index to squeeze is out of bounds.");
186204

187205
nvinfer1::Dims dims;
188206
int j = 0;
189207
for (int i = 0; i < d.nbDims; i++) {
190208
if (i != pos) {
191-
dims.d[j++] = (use_zeros && d.d[i] == -1) ? 0 : d.d[i];
209+
// If zeros are replacing dynamic/existing dimensions,
210+
// Replace all instances of -1, indicating dynamic dimension
211+
// with 0, indicating copy the dimension from another tensor
212+
// (Generally used for reshape operations)
213+
if (use_zeros && d.d[i] == -1) {
214+
dims.d[j] = 0;
215+
// If zeros already exist in the dimensions (empty tensor),
216+
// Replace all instances of 0, indicating empty dimension
217+
// with -1, indicating inherit the dimension from reshape
218+
// (Generally used for reshape operations)
219+
} else if (swap_existing_zeros && d.d[i] == 0) {
220+
dims.d[j] = -1;
221+
// Otherwise, replace the dimension with the same value from the input
222+
} else {
223+
dims.d[j] = d.d[i];
224+
}
225+
226+
j++;
192227
}
193228
}
194229
dims.nbDims = j;

core/util/trt_util.h

+2-1
Original file line number | Diff line number | Diff line change
@@ -135,8 +135,9 @@ nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to);
135135
nvinfer1::Dims toDimsTailPad(c10::IntArrayRef l, uint64_t pad_to);
136136
nvinfer1::Dims toDimsTailPad(c10::List<int64_t> l, uint64_t pad_to);
137137
nvinfer1::Dims unpadDims(const nvinfer1::Dims& d);
138+
int validateInputDimsForShuffle(const nvinfer1::Dims& d, bool input_is_dynamic);
138139
nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val = 1, bool use_zeros = true);
139-
nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros = true);
140+
nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros = true, bool swap_existing_zeros = false);
140141
nvinfer1::Dims squeezeAllDims(const nvinfer1::Dims& d, bool use_zeros_for_unknown_dims = true);
141142
nvinfer1::Dims toDims(c10::IntArrayRef l);
142143
nvinfer1::Dims toDims(c10::List<int64_t> l);

tests/core/conversion/converters/test_select.cpp

+25
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,31 @@ TEST(Converters, ATenSelectIntTwiceConvertsCorrectly) {
140140
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
141141
}
142142

143+
// Regression test: aten::select on an empty tensor (a dimension of extent 0)
// must compile to a TRT engine whose output shape matches TorchScript's.
// Values are not compared — the tensor has no elements — so only shapes are
// checked via sameShape.
TEST(Converters, ATenSelectEmptyTensorConvertsCorrectly) {
  const auto source_ir = R"IR(
    graph(%0 : Tensor):
      %2 : int = prim::Constant[value=1]()
      %3 : int = prim::Constant[value=0]()
      %4 : Tensor = aten::select(%0, %3, %2)
      return (%4))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  // Input with an explicit zero-sized dimension (dim 2).
  auto input = torch::ones({2, 20, 0, 768}).to(at::kCUDA);

  auto input_for_jit = at::clone(input);
  auto params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(parsed_graph, params, {input_for_jit});

  auto input_for_trt = at::clone(input);
  params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(parsed_graph, params, {input_for_trt});

  ASSERT_TRUE(torch_tensorrt::tests::util::sameShape(jit_results[0], trt_results[0]));
}
167+
143168
TEST(Converters, ATenNarrowStartScalarConvertsCorrectly) {
144169
const auto graph = R"IR(
145170
graph(%x.1 : Tensor):

tests/util/util.cpp

+14
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,20 @@ bool almostEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor,
2828
return result <= threshold;
2929
}
3030

31+
// Returns true iff the two tensors have identical shapes: same rank and the
// same extent along every dimension. Element values are NOT compared (this is
// the intended check for empty tensors, where no values exist to compare).
//
// Improvement: at::Tensor::sizes() returns a c10::IntArrayRef, whose equals()
// performs exactly the rank check plus element-wise comparison the original
// hand-rolled loops implemented — collapse both into one call.
bool sameShape(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor) {
  return computed_tensor.sizes().equals(gt_tensor.sizes());
}
44+
3145
bool cosineSimEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float threshold) {
3246
torch::Tensor cosine_sim = torch::nn::functional::cosine_similarity(
3347
computed_tensor.flatten(), gt_tensor.flatten(), torch::nn::functional::CosineSimilarityFuncOptions().dim(0));

tests/util/util.h

+2
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@ bool cosineSimEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tens
2121

2222
bool almostEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = ATOL, float rtol = RTOL);
2323

24+
bool sameShape(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor);
25+
2426
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
2527

2628
void pointwise_test_helper(

0 commit comments

Comments (0)