
Commit 5c0d737

feat(/cpp/api): Working INT8 Calibrator, also resolves #41
- Now creates output tensors of the correct type to accept data
- There may still be a data race in the creation of the dataloader iterator
- Quantization and dynamic shape do not currently play well together; a subsequent release of TRT may address this

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 5f36f47 commit 5c0d737
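
For orientation, this is how the calibrator path is meant to be driven from the C++ API, following the pattern in cpp/ptq/main.cpp below. A minimal sketch only: `mod` is assumed to be a loaded torch::jit::Module, `calibrator` is assumed to be constructed elsewhere, and the input shape is an example, not part of this commit.

    // Sketch: compile a module for INT8 with a PTQ calibrator (names assumed; see cpp/ptq/main.cpp)
    auto extra_info = trtorch::ExtraInfo({{32, 3, 32, 32}});
    extra_info.op_precision = torch::kI8;      // request INT8 kernels
    extra_info.ptq_calibrator = calibrator;    // TensorRT uses this to pick quantization scales
    extra_info.max_batch_size = 32;
    auto trt_mod = trtorch::CompileGraph(mod, extra_info);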

File tree: 7 files changed (+48 −29 lines)


README.md

+1-1
@@ -27,7 +27,7 @@ auto results = trt_mod.forward({in_tensor});
 
 > Notes on running in lower precisions:
 > - Set precision with extra_info.op_precision
-> - The module should be left in FP32 before compilation
+> - The module should be left in FP32 before compilation (FP16 can support half tensor models)
 > - In FP16 only input tensors should be converted to FP16, other precisions use FP32
 
 ## Platform Support
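
To make the updated note concrete, here is a hedged sketch of an FP16 compile using the same ExtraInfo API that cpp/ptq/main.cpp uses for INT8. The half-precision enum value and the input shape are assumptions; only the INT8 path appears in this commit.

    // Sketch: module weights stay FP32; only inputs are cast to half for FP16 inference
    auto extra_info = trtorch::ExtraInfo({{1, 3, 224, 224}});   // example shape (assumed)
    extra_info.op_precision = torch::kF16;                      // assumed value for half precision
    auto trt_mod = trtorch::CompileGraph(mod, extra_info);
    auto in_tensor = torch::randn({1, 3, 224, 224}, {torch::kCUDA}).to(torch::kHalf);
    auto results = trt_mod.forward({in_tensor});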

core/conversion/conversion.cpp

+2-2
@@ -133,7 +133,7 @@ void AddInputs(ConversionCtx* ctx,
              "Expected dimension specifications for all input tensors" \
              << ", but found " << input_tensors.size() \
              << " input tensors and " \
-             << input_dims.size() << "dimension specs (conversion.AddInputs)");
+             << input_dims.size() << " dimension specs (conversion.AddInputs)");
 
     auto profile = ctx->builder->createOptimizationProfile();
 
@@ -235,7 +235,7 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
         if (!OpSupported(n)) {
             auto schema = n->maybeSchema();
             TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
-                << " (conversion.AddLayer)");
+                << " (conversion.VerifyCoverterSupportForBloxk");
             std::stringstream ss;
             ss << *schema;
             unsupported_ops.insert(ss.str());

core/conversion/conversionctx/ConversionCtx.cpp

+1-1
@@ -51,7 +51,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
     case nvinfer1::DataType::kINT8:
         TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does support INT8");
         cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
-        input_type = nvinfer1::DataType::kINT8;
+        input_type = nvinfer1::DataType::kFLOAT;
         // If the calibrator is nullptr then TRT will use default quantization
         cfg->setInt8Calibrator(settings.calibrator);
         break;
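
For reference, the TensorRT pattern this hunk lines up with — a minimal sketch against the public nvinfer1 API, with `builder`, `network`, `dims`, and `calibrator` assumed to exist. Keeping the input binding at kFLOAT means the calibrator sees FP32 activations and TensorRT performs the quantization internally.

    // Sketch of INT8 builder configuration with TensorRT (surrounding setup assumed)
    nvinfer1::IBuilderConfig* cfg = builder->createBuilderConfig();
    if (builder->platformHasFastInt8()) {
        cfg->setFlag(nvinfer1::BuilderFlag::kINT8);   // enable INT8 kernels
        cfg->setInt8Calibrator(calibrator);           // nullptr -> TRT default quantization
    }
    // Inputs are still declared as FP32; TensorRT quantizes activations internally
    auto input = network->addInput("input", nvinfer1::DataType::kFLOAT, dims);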

core/execution/register_trt_op.cpp

+4-12
@@ -17,20 +17,11 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
     contig_inputs.reserve(inputs.size());
     for (size_t i = 0; i < inputs.size(); i++) {
         TRTORCH_CHECK(inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
-        auto expected_type = torch::kF32;
-        switch (ctx->getEngine().getBindingDataType(i)) {
-        case nvinfer1::DataType::kHALF:
-            expected_type = torch::kF16;
-            break;
-        case nvinfer1::DataType::kFLOAT:
-        case nvinfer1::DataType::kINT8:
-        default:
-            expected_type = torch::kF32;
-        }
+        auto expected_type = util::toATenDType(ctx->getEngine().getBindingDataType(i));
         TRTORCH_CHECK(inputs[i].dtype() == expected_type, "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
         auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
         auto shape = core::util::toVec(dims);
-        contig_inputs.push_back(inputs[i].to(at::kCUDA).view(shape).contiguous());
+        contig_inputs.push_back(inputs[i].view(shape).contiguous());
         LOG_DEBUG("In shape:" << shape);
         ctx->setBindingDimensions(i, dims);
         gpu_handles.push_back(contig_inputs.back().data_ptr());
@@ -43,7 +34,8 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
         auto out_shape = ctx->getBindingDimensions(o);
         //LOG_DEBUG("Output: " << engine->getBindingName(o) << " out shape: " << out_shape);
         auto dims = core::util::toVec(out_shape);
-        outputs.push_back(at::empty(dims, {at::kCUDA}).contiguous());
+        auto type = util::toATenDType(ctx->getEngine().getBindingDataType(o));
+        outputs.push_back(at::empty(dims, {at::kCUDA}).to(type).contiguous());
         gpu_handles.push_back(outputs[outputs.size() - 1].data_ptr());
     }
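
As a side note, the dtype could also be supplied at allocation time instead of converting afterwards — a hedged sketch of an equivalent allocation, not what the commit does:

    // Equivalent allocation sketch (alternative, not in this commit)
    auto type = util::toATenDType(ctx->getEngine().getBindingDataType(o));
    auto options = at::TensorOptions().device(at::kCUDA).dtype(type);
    outputs.push_back(at::empty(dims, options).contiguous());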

core/util/trt_util.cpp

+22-4
@@ -15,18 +15,18 @@ nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to) {
         LOG_DEBUG("Requested padding of dimensions to " << pad_to << " but found " << l.size() << " dimensions, not going to pad");
         return toDims(l);
     }
-
+
     if (pad_to > nvinfer1::Dims::MAX_DIMS) {
         //TODO: Handle this with exceptions or whatever
         LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
     }
-
+
     nvinfer1::Dims dims;
     dims.nbDims = pad_to;
     for (size_t i = 0; i < pad_to - l.size(); i++) {
         dims.d[i] = 1;
     }
-
+
     for (size_t i = pad_to - l.size(); i < pad_to; i++) {
         dims.d[i] = l[i - (pad_to - l.size())];
     }
@@ -58,7 +58,7 @@ nvinfer1::Dims toDims(c10::List<int64_t> l) {
     }
     return dims;
 }
-
+
 std::vector<int64_t> toVec(nvinfer1::Dims d) {
     std::vector<int64_t> dims;
     for (int i = 0; i < d.nbDims; i++) {
@@ -110,8 +110,26 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
     };
     return at_trt_type_map;
 }
+
+const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_map() {
+    static const std::unordered_map<nvinfer1::DataType, at::ScalarType> trt_at_type_map = {
+        {nvinfer1::DataType::kFLOAT, at::kFloat},
+        {nvinfer1::DataType::kHALF, at::kHalf},
+        {nvinfer1::DataType::kINT32, at::kInt},
+        {nvinfer1::DataType::kINT8, at::kChar},
+    };
+    return trt_at_type_map;
+}
 } // namespace
 
+const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_aten_type_map() {
+    return get_trt_at_type_map();
+}
+
+at::ScalarType toATenDType(nvinfer1::DataType t) {
+    return get_trt_aten_type_map().at(t);
+}
+
 const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_aten_trt_type_map() {
     return get_at_trt_type_map();
 }
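
A quick usage sketch of the new helper alongside the pre-existing toTRTDataType declared in trt_util.h; values here are illustrative only.

    // Round trip between ATen and TensorRT dtypes using the core/util helpers
    auto trt_type  = trtorch::core::util::toTRTDataType(at::kHalf);   // -> nvinfer1::DataType::kHALF
    auto aten_type = trtorch::core::util::toATenDType(trt_type);      // -> at::kHalf
    auto out = at::empty({1, 1000}, at::TensorOptions().device(at::kCUDA).dtype(aten_type));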

core/util/trt_util.h

+3-2
@@ -21,7 +21,7 @@ inline bool operator==(const nvinfer1::Dims& in1, const nvinfer1::Dims& in2) {
     }
 
     // TODO maybe look to support broadcasting comparisons
-
+
     for (int64_t i = 0; i < in1.nbDims; i++) {
         if (in1.d[i] != in2.d[i]) {
             return false;
@@ -85,11 +85,12 @@ nvinfer1::DimsHW toDimsHW(c10::IntArrayRef l);
 std::vector<int64_t> toVec(nvinfer1::Dims d);
 std::string toStr(nvinfer1::Dims d);
 
+at::ScalarType toATenDType(nvinfer1::DataType t);
 nvinfer1::DataType toTRTDataType(at::ScalarType t);
 c10::optional<nvinfer1::DataType>toTRTDataType(caffe2::TypeMeta dtype);
 
 const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_aten_trt_type_map();
-
+
 } // namespace util
 } // namespace core
 } // namespace trtorch

cpp/ptq/main.cpp

+15-7
@@ -13,7 +13,7 @@
 #include <sys/stat.h>
 
 int main(int argc, const char* argv[]) {
-    trtorch::logging::set_reportable_log_level(trtorch::logging::kINFO);
+    trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kERROR);
     if (argc < 3) {
         std::cerr << "usage: ptq <path-to-module> <path-to-cifar10>\n";
         return -1;
@@ -50,11 +50,13 @@ int main(int argc, const char* argv[]) {
     // Configure settings for compilation
    auto extra_info = trtorch::ExtraInfo({input_shape});
     // Set operating precision to INT8
-    extra_info.op_precision = torch::kFI8;
+    extra_info.op_precision = torch::kI8;
     // Use the TensorRT Entropy Calibrator
     extra_info.ptq_calibrator = calibrator;
     // Set max batch size for the engine
     extra_info.max_batch_size = 32;
+    // Set a larger workspace
+    extra_info.workspace_size = 1 << 28;
 
     mod.eval();
 
@@ -82,6 +84,7 @@ int main(int argc, const char* argv[]) {
     std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%" << std::endl;
 
     // Compile Graph
+    std::cout << "Compiling and quantizing module" << std::endl;
     auto trt_mod = trtorch::CompileGraph(mod, extra_info);
 
     // Check the INT8 accuracy in TRT
@@ -91,22 +94,27 @@ int main(int argc, const char* argv[]) {
         auto images = batch.data.to(torch::kCUDA);
         auto targets = batch.target.to(torch::kCUDA);
 
+        if (images.sizes()[0] < 32) {
+            // To handle smaller batches util Optimization profiles work with Int8
+            auto diff = 32 - images.sizes()[0];
+            auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA});
+            auto target_padding = torch::zeros({diff}, {torch::kCUDA});
+            images = torch::cat({images, img_padding}, 0);
+            targets = torch::cat({targets, target_padding}, 0);
+        }
+
         auto outputs = trt_mod.forward({images});
         auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
         predictions = predictions.reshape(predictions.sizes()[0]);
 
         if (predictions.sizes()[0] != targets.sizes()[0]) {
-            // To handle smaller batches util Optimization profiles work
+            // To handle smaller batches util Optimization profiles work with Int8
             predictions = predictions.slice(0, 0, targets.sizes()[0]);
         }
 
-        std:: cout << predictions << targets << std::endl;
-
         total += targets.sizes()[0];
         correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
-        std::cout << total << " " << correct << std::endl;
     }
-    std::cout << total << " " << correct << std::endl;
     std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" << std::endl;
 
     // Time execution in INT8
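
The hunk ends just before the INT8 timing section. A hedged sketch of the kind of loop that section implies (the actual benchmark code is not part of this diff; the iteration count, shape, and use of std::chrono are assumptions):

    // Sketch: time INT8 inference; copying the last output to the host forces pending GPU work to finish
    auto images = torch::randn({32, 3, 32, 32}, {torch::kCUDA});
    c10::IValue out;
    auto start = std::chrono::high_resolution_clock::now();
    for (int i = 0; i < 100; i++) {
        out = trt_mod.forward({images});
    }
    auto synced = out.toTensor().to(torch::kCPU);
    auto end = std::chrono::high_resolution_clock::now();
    std::cout << "Mean latency: "
              << std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 100.0
              << " us" << std::endl;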
