@@ -13,7 +13,7 @@
 #include <sys/stat.h>
 
 int main(int argc, const char* argv[]) {
-    trtorch::logging::set_reportable_log_level(trtorch::logging::kINFO);
+    trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kERROR);
     if (argc < 3) {
         std::cerr << "usage: ptq <path-to-module> <path-to-cifar10>\n";
         return -1;
@@ -50,11 +50,13 @@ int main(int argc, const char* argv[]) {
     // Configure settings for compilation
     auto extra_info = trtorch::ExtraInfo({input_shape});
     // Set operating precision to INT8
-    extra_info.op_precision = torch::kFI8;
+    extra_info.op_precision = torch::kI8;
     // Use the TensorRT Entropy Calibrator
     extra_info.ptq_calibrator = calibrator;
     // Set max batch size for the engine
     extra_info.max_batch_size = 32;
+    // Set a larger workspace
+    extra_info.workspace_size = 1 << 28;
 
     mod.eval();
 
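For context on the new workspace setting, a rough sketch (not part of the commit): 1 << 28 bytes is 256 MiB, which TensorRT can use as scratch space while benchmarking kernel tactics during engine building. The commented alternative below is a hypothetical larger value:

    // 1 << 28 bytes == 256 MiB of build-time scratch space for TensorRT.
    extra_info.workspace_size = 1 << 28;
    // Hypothetical larger setting, if the GPU has memory to spare:
    // extra_info.workspace_size = 1ULL << 30;  // 1 GiB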
@@ -82,6 +84,7 @@ int main(int argc, const char* argv[]) {
     std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%" << std::endl;
 
     // Compile Graph
+    std::cout << "Compiling and quantizing module" << std::endl;
     auto trt_mod = trtorch::CompileGraph(mod, extra_info);
 
     // Check the INT8 accuracy in TRT
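As a side note on usage (an assumption about the surrounding code, not shown in this diff): CompileGraph returns an ordinary torch::jit::Module, so the quantized result can be serialized like any TorchScript module; the output path here is hypothetical:

    // Sketch: persist the INT8-compiled module for later deployment.
    // "trt_ptq_cifar10.ts" is a hypothetical filename.
    trt_mod.save("trt_ptq_cifar10.ts");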
@@ -91,22 +94,27 @@ int main(int argc, const char* argv[]) {
         auto images = batch.data.to(torch::kCUDA);
         auto targets = batch.target.to(torch::kCUDA);
 
+        if (images.sizes()[0] < 32) {
+            // To handle smaller batches until Optimization profiles work with Int8
+            auto diff = 32 - images.sizes()[0];
+            auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA});
+            auto target_padding = torch::zeros({diff}, {torch::kCUDA});
+            images = torch::cat({images, img_padding}, 0);
+            targets = torch::cat({targets, target_padding}, 0);
+        }
+
         auto outputs = trt_mod.forward({images});
         auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
         predictions = predictions.reshape(predictions.sizes()[0]);
 
         if (predictions.sizes()[0] != targets.sizes()[0]) {
-            // To handle smaller batches until Optimization profiles work
+            // To handle smaller batches until Optimization profiles work with Int8
             predictions = predictions.slice(0, 0, targets.sizes()[0]);
         }
 
-        std::cout << predictions << targets << std::endl;
-
         total += targets.sizes()[0];
         correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
-        std::cout << total << " " << correct << std::endl;
     }
-    std::cout << total << " " << correct << std::endl;
     std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" << std::endl;
 
     // Time execution in INT8
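The padding workaround added above generalizes to any fixed-batch engine. A minimal standalone sketch, assuming LibTorch; pad_batch is a hypothetical helper name, not part of this commit:

    #include <torch/torch.h>

    // Pad a batch along dim 0 up to the engine's fixed max batch size,
    // since the INT8 engine here is built for a fixed batch of 32.
    torch::Tensor pad_batch(const torch::Tensor& batch, int64_t max_batch_size) {
        auto diff = max_batch_size - batch.sizes()[0];
        if (diff <= 0) {
            return batch;
        }
        auto pad_shape = batch.sizes().vec();
        pad_shape[0] = diff;
        auto padding = torch::zeros(pad_shape, batch.options());
        return torch::cat({batch, padding}, 0);
    }

The corresponding predictions then need to be sliced back to the real batch size before comparing against targets, as the diff does with predictions.slice(0, 0, targets.sizes()[0]).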