@@ -180,15 +180,50 @@ nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val, bool use
180
180
return dims;
181
181
}
182
182
183
- nvinfer1::Dims squeezeDims (const nvinfer1::Dims& d, int pos, bool use_zeros) {
183
+ int validateInputDimsForShuffle (const nvinfer1::Dims& d, bool input_is_dynamic) {
184
+ int num_zeros_detected = 0 ;
185
+
186
+ // For each dimension, increment counter if that dimension has value 0
187
+ for (int i = 0 ; i < d.nbDims ; i++) {
188
+ if (d.d [i] == 0 ) {
189
+ num_zeros_detected++;
190
+ }
191
+ }
192
+
193
+ // If the tensor from which the dimensions originate has dynamic shape and more than 1
194
+ // zero dimension is detected, this constitutes an invalid shape to the TRT Shuffle Layer,
195
+ // since dynamic dimensions to Shuffle Layers are generally represented with a 0
196
+ // denoting to inherit the dimension from the input tensor, thus causing an
197
+ // overload of the "0" dimension
198
+ return (input_is_dynamic && num_zeros_detected > 1 ) ? -1 : num_zeros_detected;
199
+ }
200
+
201
+ nvinfer1::Dims squeezeDims (const nvinfer1::Dims& d, int pos, bool use_zeros, bool swap_existing_zeros) {
184
202
// acceptable range for pos is [0, d.nbDims]
185
203
TORCHTRT_ASSERT (pos >= 0 && pos <= d.nbDims , " ERROR: Index to squeeze is out of bounds." );
186
204
187
205
nvinfer1::Dims dims;
188
206
int j = 0 ;
189
207
for (int i = 0 ; i < d.nbDims ; i++) {
190
208
if (i != pos) {
191
- dims.d [j++] = (use_zeros && d.d [i] == -1 ) ? 0 : d.d [i];
209
+ // If zeros are replacing dynamic/existing dimensions,
210
+ // Replace all instances of -1, indicating dynamic dimension
211
+ // with 0, indicating copy the dimension from another tensor
212
+ // (Generally used for reshape operations)
213
+ if (use_zeros && d.d [i] == -1 ) {
214
+ dims.d [j] = 0 ;
215
+ // If zeros already exist in the dimensions (empty tensor),
216
+ // Replace all instances of 0, indicating empty dimension
217
+ // with -1, indicating inherit the dimension from reshape
218
+ // (Generally used for reshape operations)
219
+ } else if (swap_existing_zeros && d.d [i] == 0 ) {
220
+ dims.d [j] = -1 ;
221
+ // Otherwise, replace the dimension with the same value from the input
222
+ } else {
223
+ dims.d [j] = d.d [i];
224
+ }
225
+
226
+ j++;
192
227
}
193
228
}
194
229
dims.nbDims = j;
0 commit comments