@@ -37,7 +37,7 @@ bool AdaptivePoolingConverter(
37
37
ConversionCtx* ctx,
38
38
const torch::jit::Node* n,
39
39
args& args,
40
- nvinfer1::PoolingType pool_type) {
40
+ nvinfer1::PoolingType pool_type, const std::string& mode ) {
41
41
auto in = args[0 ].ITensorOrFreeze (ctx);
42
42
auto out_size = util::toDims (args[1 ].unwrapToIntList ());
43
43
@@ -48,15 +48,7 @@ bool AdaptivePoolingConverter(
48
48
}
49
49
50
50
auto orig_dims = in->getDimensions ();
51
- bool expandDims = (orig_dims.nbDims < 4 );
52
- TORCHTRT_CHECK (orig_dims.nbDims > 2 , " Unable to create pooling layer from node: " << *n);
53
- if (expandDims) {
54
- in = addPadding (ctx, n, in, 4 , false , false );
55
- }
56
-
57
- if (out_size.nbDims == 1 ) {
58
- out_size = util::unsqueezeDims (out_size, 0 , 1 );
59
- }
51
+ TORCHTRT_CHECK (orig_dims.nbDims > 1 , " Unable to create pooling layer from node: " << *n);
60
52
61
53
auto in_shape = util::toVec (in->getDimensions ());
62
54
nvinfer1::ILayer* new_layer = nullptr ;
@@ -90,10 +82,6 @@ bool AdaptivePoolingConverter(
90
82
int32_t use_scales_casted = 0 ;
91
83
f.emplace_back (nvinfer1::PluginField (" use_scales" , &use_scales_casted, nvinfer1::PluginFieldType::kINT32 , 1 ));
92
84
93
- std::string mode = " adaptive_avg_pool2d" ;
94
- if (pool_type == nvinfer1::PoolingType::kMAX ) {
95
- mode = " adaptive_max_pool2d" ;
96
- }
97
85
f.emplace_back (nvinfer1::PluginField (" mode" , &mode, nvinfer1::PluginFieldType::kCHAR , 1 ));
98
86
99
87
fc.nbFields = f.size ();
@@ -110,7 +98,7 @@ bool AdaptivePoolingConverter(
110
98
TORCHTRT_CHECK (new_layer, " Unable to create pooling (interpolation) plugin from node" << *n);
111
99
112
100
new_layer->setName (util::node_info (n).c_str ());
113
- auto layer_output = addUnpadding (ctx, n, new_layer->getOutput (0 ), orig_dims. nbDims , false , false );
101
+ auto layer_output = new_layer->getOutput (0 );
114
102
115
103
ctx->AssociateValueAndTensor (n->outputs ()[0 ], layer_output);
116
104
LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
@@ -238,15 +226,15 @@ auto pooling_registrations TORCHTRT_UNUSED =
238
226
}})
239
227
.pattern({" aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> (Tensor)" ,
240
228
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
241
- return AdaptivePoolingConverter (ctx, n, args, nvinfer1::PoolingType::kAVERAGE );
229
+ return AdaptivePoolingConverter (ctx, n, args, nvinfer1::PoolingType::kAVERAGE , " adaptive_avg_pool1d " );
242
230
}})
243
231
.pattern({" aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)" ,
244
232
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
245
- return AdaptivePoolingConverter (ctx, n, args, nvinfer1::PoolingType::kAVERAGE );
233
+ return AdaptivePoolingConverter (ctx, n, args, nvinfer1::PoolingType::kAVERAGE , " adaptive_avg_pool2d " );
246
234
}})
247
235
.pattern({" aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)" ,
248
236
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
249
- return AdaptivePoolingConverter (ctx, n, args, nvinfer1::PoolingType::kMAX );
237
+ return AdaptivePoolingConverter (ctx, n, args, nvinfer1::PoolingType::kMAX , " adaptive_max_pool2d " );
250
238
}});
251
239
} // namespace
252
240
} // namespace impl
0 commit comments