Skip to content

Commit df74136

Browse files
committed
feat(//tests): New optional accuracy tests to check INT8 and FP16
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent b989c7f commit df74136

13 files changed

+508
-56
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@ py/.eggs
2121
cpp/ptq/training/vgg16/data/*
2222
*.bin
2323
cpp/ptq/datasets/data/
24+
tests/accuracy/datasets/data/*
2425
._.DS_Store
2526
*.tar.gz

tests/BUILD

+8-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,11 @@ test_suite(
55
"//tests/modules:test_modules"
66
],
77
)
8-
8+
9+
# Umbrella suite: everything in ":tests" plus the accuracy suite from
# //tests/accuracy.
test_suite(
    name = "required_and_optional_tests",
    tests = [
        ":tests",
        "//tests/accuracy:test_accuracy",
    ],
)

tests/accuracy/BUILD

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Serialized TorchScript modules that the accuracy tests load at runtime.
filegroup(
    name = "jit_models",
    srcs = glob(["**/*.jit.pt"]),
)

# Aggregates the per-precision accuracy tests declared below.
test_suite(
    name = "test_accuracy",
    tests = [
        ":test_int8_accuracy",
        ":test_fp16_accuracy",
        ":test_fp32_accuracy",
    ],
)

# One cc_test per precision. The three rules differ only in the precision tag
# embedded in the target name and source file name.
[
    cc_test(
        name = "test_%s_accuracy" % precision,
        srcs = ["test_%s_accuracy.cpp" % precision],
        data = [
            ":jit_models",
        ],
        deps = [
            ":accuracy_test",
            "//tests/accuracy/datasets:cifar10",
        ],
    )
    for precision in [
        "int8",
        "fp16",
        "fp32",
    ]
]

# Stand-alone binary built against the same fixture, dataset and models.
cc_binary(
    name = "test",
    srcs = ["test.cpp"],
    data = [
        ":jit_models",
    ],
    deps = [
        ":accuracy_test",
        "//tests/accuracy/datasets:cifar10",
    ],
)

# Header-only gtest fixture shared by every accuracy test.
cc_library(
    name = "accuracy_test",
    hdrs = ["accuracy_test.h"],
    deps = [
        "//cpp/api:trtorch",
        "//tests/util",
        "@googletest//:gtest_main",
        "@libtorch//:libtorch",
    ],
)

tests/accuracy/accuracy_test.h

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#include <utility>
2+
#include "torch/script.h"
3+
#include "gtest/gtest.h"
4+
#include "tests/util/util.h"
5+
#include "trtorch/trtorch.h"
6+
#include "c10/cuda/CUDACachingAllocator.h"
7+
8+
// TODO: Extend this to support other datasets
9+
class AccuracyTests
10+
: public testing::TestWithParam<std::string> {
11+
public:
12+
void SetUp() override {
13+
auto params = GetParam();
14+
auto module_path = params;
15+
try {
16+
// Deserialize the ScriptModule from a file using torch::jit::load().
17+
mod = torch::jit::load(module_path);
18+
}
19+
catch (const c10::Error& e) {
20+
std::cerr << "error loading the model\n";
21+
return;
22+
}
23+
}
24+
25+
void TearDown() {
26+
c10::cuda::CUDACachingAllocator::emptyCache();
27+
}
28+
protected:
29+
torch::jit::script::Module mod;
30+
};

tests/accuracy/datasets/BUILD

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package(default_visibility = ["//visibility:public"])

# CIFAR-10 dataset loader; the binary batch files travel with it as runtime data.
cc_library(
    name = "cifar10",
    srcs = ["cifar10.cpp"],
    hdrs = ["cifar10.h"],
    data = [":cifar10_data"],
    deps = ["@libtorch//:libtorch"],
)

# Binary batch files found under data/cifar-10-batches-bin.
filegroup(
    name = "cifar10_data",
    srcs = glob(["data/cifar-10-batches-bin/**/*.bin"]),
)

tests/accuracy/datasets/cifar10.cpp

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#include "tests/accuracy/datasets/cifar10.h"

#include "torch/torch.h"
#include "torch/data/example.h"
#include "torch/types.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
15+
16+
namespace datasets {
17+
namespace {
18+
// File naming of the training and test batch files.
constexpr const char* kTrainFilenamePrefix = "data_batch_";
constexpr const char* kTestFilename = "test_batch.bin";
constexpr uint32_t kNumTrainFiles = 5;

// Record layout inside a batch file, in bytes: label first, then the image.
constexpr size_t kLabelSize = 1;
constexpr size_t kImageSize = 3072;

// Shape of a decoded image: kImageChannels x kImageDim x kImageDim.
constexpr size_t kImageChannels = 3;
constexpr size_t kImageDim = 32;

// Number of records read from each batch file.
constexpr size_t kBatchSize = 10000;
26+
27+
std::pair<torch::Tensor, torch::Tensor> read_batch(const std::string& path) {
28+
std::ifstream batch;
29+
batch.open(path, std::ios::in|std::ios::binary|std::ios::ate);
30+
31+
auto file_size = batch.tellg();
32+
std::unique_ptr<char[]> buf(new char[file_size]);
33+
34+
batch.seekg(0, std::ios::beg);
35+
batch.read(buf.get(), file_size);
36+
batch.close();
37+
38+
std::vector<uint8_t> labels;
39+
std::vector<torch::Tensor> images;
40+
labels.reserve(kBatchSize);
41+
images.reserve(kBatchSize);
42+
43+
for (size_t i = 0; i < kBatchSize; i++) {
44+
uint8_t label = buf[i * (kImageSize + kLabelSize)];
45+
std::vector<uint8_t> image;
46+
image.reserve(kImageSize);
47+
std::copy(&buf[i * (kImageSize + kLabelSize) + 1], &buf[i * (kImageSize + kLabelSize) + kImageSize], std::back_inserter(image));
48+
labels.push_back(label);
49+
auto image_tensor = torch::from_blob(image.data(),
50+
{kImageChannels, kImageDim, kImageDim},
51+
torch::TensorOptions().dtype(torch::kU8)).to(torch::kF32);
52+
images.push_back(image_tensor);
53+
}
54+
55+
auto labels_tensor = torch::from_blob(labels.data(),
56+
{kBatchSize},
57+
torch::TensorOptions().dtype(torch::kU8)).to(torch::kF32);
58+
assert(labels_tensor.size(0) == kBatchSize);
59+
60+
auto images_tensor = torch::stack(images);
61+
assert(images_tensor.size(0) == kBatchSize);
62+
63+
return std::make_pair(images_tensor, labels_tensor);
64+
}
65+
66+
std::pair<torch::Tensor, torch::Tensor> read_train_data(const std::string& root) {
67+
std::vector<torch::Tensor> images, targets;
68+
for(uint32_t i = 1; i <= 5; i++) {
69+
std::stringstream ss;
70+
ss << root << '/' << kTrainFilenamePrefix << i << ".bin";
71+
auto batch = read_batch(ss.str());
72+
images.push_back(batch.first);
73+
targets.push_back(batch.second);
74+
}
75+
76+
torch::Tensor image_tensor = std::accumulate(++images.begin(), images.end(), *images.begin(), [&](torch::Tensor a, torch::Tensor b) {return torch::cat({a, b}, 0);});
77+
torch::Tensor target_tensor = std::accumulate(++targets.begin(), targets.end(), *targets.begin(), [&](torch::Tensor a, torch::Tensor b) {return torch::cat({a, b}, 0);});
78+
79+
return std::make_pair(image_tensor, target_tensor);
80+
}
81+
82+
std::pair<torch::Tensor, torch::Tensor> read_test_data(const std::string& root) {
83+
std::stringstream ss;
84+
ss << root << '/' << kTestFilename;
85+
return read_batch(ss.str());
86+
}
87+
}
88+
89+
// Loads the split selected by `mode` from the directory `root`.
//
// Mode::kTrain loads the training batches; any other mode loads the test
// batch. After loading, images_ and targets_ must have the same number of
// entries along dimension 0.
CIFAR10::CIFAR10(const std::string& root, Mode mode)
    : mode_(mode) {

    auto data = (mode_ == Mode::kTrain) ? read_train_data(root)
                                        : read_test_data(root);

    images_ = std::move(data.first);
    targets_ = std::move(data.second);

    // Every image needs exactly one target.
    assert(images_.size(0) == targets_.size(0));
}
103+
104+
// Returns the (image, target) pair stored at `index`.
torch::data::Example<> CIFAR10::get(size_t index) {
    torch::data::Example<> sample(images_[index], targets_[index]);
    return sample;
}
107+
108+
// Number of samples, i.e. the extent of images_ along dimension 0.
c10::optional<size_t> CIFAR10::size() const {
    const auto num_samples = images_.size(0);
    return num_samples;
}
111+
112+
bool CIFAR10::is_train() const noexcept {
113+
return mode_ == Mode::kTrain;
114+
}
115+
116+
// Read-only access to the image tensor held by the dataset.
const torch::Tensor& CIFAR10::images() const {
    return this->images_;
}
119+
120+
// Read-only access to the target tensor held by the dataset.
const torch::Tensor& CIFAR10::targets() const {
    return this->targets_;
}
123+
124+
// Keeps only the first `new_size` samples and hands the dataset back as an
// rvalue so calls can be chained on a temporary, e.g.
// CIFAR10(...).use_subset(n).map(...).
CIFAR10&& CIFAR10::use_subset(int64_t new_size) {
    const int64_t available = images_.size(0);
    assert(new_size <= available);

    targets_ = targets_.slice(0, 0, new_size);
    images_ = images_.slice(0, 0, new_size);
    return std::move(*this);
}
130+
131+
} // namespace datasets
132+

tests/accuracy/datasets/cifar10.h

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#pragma once
2+
3+
#include "torch/data/datasets/base.h"
4+
#include "torch/data/example.h"
5+
#include "torch/types.h"
6+
7+
#include <cstddef>
8+
#include <string>
9+
10+
namespace datasets {
11+
// The CIFAR10 Dataset
12+
class CIFAR10 : public torch::data::datasets::Dataset<CIFAR10> {
public:
    /// Which split of the dataset to load.
    enum class Mode { kTrain, kTest };

    /// Loads CIFAR10 from the un-tarred binary release.
    /// The dataset can be found at https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
    /// @param root directory that contains the content of the tarball
    /// @param mode split to load (training by default)
    explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain);

    /// @return the (image, target) pair at `index`
    torch::data::Example<> get(size_t index) override;

    /// @return the number of samples in the dataset
    c10::optional<size_t> size() const override;

    /// @return true when the dataset is in training mode
    bool is_train() const noexcept;

    /// @return all images stacked into a single tensor
    const torch::Tensor& images() const;

    /// @return all targets stacked into a single tensor
    const torch::Tensor& targets() const;

    /// Trims the dataset to its first `new_size` pairs.
    /// @return this dataset as an rvalue reference
    CIFAR10&& use_subset(int64_t new_size);

private:
    Mode mode_;
    torch::Tensor images_;
    torch::Tensor targets_;
};
45+
} // namespace datasets

tests/accuracy/test_fp16_accuracy.cpp

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#include "accuracy_test.h"
2+
#include "datasets/cifar10.h"
3+
#include "torch/torch.h"
4+
5+
// Evaluates the loaded module on the first 3200 CIFAR10 test samples, once
// through the TorchScript module and once through the module compiled by
// TRTorch with FP16 op precision, then checks that the two accuracies agree.
TEST_P(AccuracyTests, FP16AccuracyIsClose) {
    auto normalize = torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
                                                          {0.2023, 0.1994, 0.2010});
    auto eval_dataset = datasets::CIFAR10("tests/accuracy/datasets/data/cifar-10-batches-bin/", datasets::CIFAR10::Mode::kTest)
                            .use_subset(3200)
                            .map(std::move(normalize))
                            .map(torch::data::transforms::Stack<>());

    auto loader_options = torch::data::DataLoaderOptions()
                              .batch_size(32)
                              .workers(2);
    auto eval_dataloader = torch::data::make_data_loader(std::move(eval_dataset), loader_options);

    // Check the FP32 accuracy in JIT
    torch::Tensor jit_correct = torch::zeros({1}, {torch::kCUDA});
    torch::Tensor jit_total = torch::zeros({1}, {torch::kCUDA});
    for (auto& batch : *eval_dataloader) {
        auto images = batch.data.to(torch::kCUDA);
        auto targets = batch.target.to(torch::kCUDA);

        auto logits = mod.forward({images}).toTensor();
        // Element 1 of torch::max's result holds the argmax indices along dim 1.
        auto predictions = std::get<1>(torch::max(logits, 1, false));

        jit_total += targets.size(0);
        jit_correct += torch::sum(torch::eq(predictions, targets));
    }
    torch::Tensor jit_accuracy = jit_correct / jit_total;

    // Compile the same module with FP16 op precision for a fixed input shape.
    std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
    auto extra_info = trtorch::ExtraInfo({input_shape});
    extra_info.op_precision = torch::kF16;

    auto trt_mod = trtorch::CompileGraph(mod, extra_info);

    // Same evaluation loop, with inputs and targets cast to half precision.
    torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA});
    torch::Tensor trt_total = torch::zeros({1}, {torch::kCUDA});
    for (auto& batch : *eval_dataloader) {
        auto images = batch.data.to(torch::kCUDA).to(torch::kF16);
        auto targets = batch.target.to(torch::kCUDA).to(torch::kF16);

        auto logits = trt_mod.forward({images}).toTensor();
        auto predictions = std::get<1>(torch::max(logits, 1, false));
        predictions = predictions.reshape(predictions.size(0));

        trt_total += targets.size(0);
        trt_correct += torch::sum(torch::eq(predictions, targets));
    }
    torch::Tensor trt_accuracy = trt_correct / trt_total;

    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3));
}
52+
53+
54+
// Runs the AccuracyTests fixture against the VGG16 CIFAR10 TorchScript model.
INSTANTIATE_TEST_SUITE_P(
    FP16AccuracyIsCloseSuite,
    AccuracyTests,
    testing::Values("tests/accuracy/vgg16_cifar10.jit.pt"));

0 commit comments

Comments
 (0)