forked from pytorch/TensorRT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil.h
77 lines (61 loc) · 2.77 KB
/
util.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#pragma once
#include <ATen/ATen.h>
#include <string>
#include <vector>
#include "ATen/Tensor.h"
#include "core/ir/ir.h"
#include "core/util/prelude.h"
#include "torch/csrc/jit/ir/irparser.h"
// Tolerances used by the comparison helpers below.
// `constexpr` (rather than plain `const`) guarantees compile-time,
// constant-initialized values usable in default arguments and other
// constant expressions; linkage is unchanged (internal per TU, as before).
constexpr float ATOL = 1e-8; // absolute tolerance for almostEqual
constexpr float RTOL = 1e-5; // relative tolerance for almostEqual
constexpr float COSINE_THRESHOLD = 0.99f; // minimum cosine similarity for cosineSimEqual
constexpr float THRESHOLD_E5 = 1e-5; // generic 1e-5 threshold — usage not visible in this header
namespace torch_tensorrt {
namespace tests {
namespace util {
// Returns true if the cosine similarity between the two tensors is at least
// `threshold` (implementation lives in the corresponding .cpp).
bool cosineSimEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float threshold = COSINE_THRESHOLD);
// Returns true if the tensors are element-wise close within the given
// absolute/relative tolerances (allclose-style comparison; see definition).
bool almostEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = ATOL, float rtol = RTOL);
// Returns true if the two tensors have the same shape.
bool sameShape(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor);
// Returns true if the two tensors are exactly (bitwise/value) equal — no tolerance.
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
// Drives an element-wise (pointwise) op test for the given graph IR with one
// or two inputs of the given shapes/dtypes. `negative_input` presumably
// requests inputs containing negative values — confirm against the definition.
void pointwise_test_helper(
    std::string graph_ir,
    bool singleInput,
    bool dynamicInput = false,
    std::vector<int64_t> shape1 = {5},
    std::vector<int64_t> shape2 = {5},
    bool negative_input = false,
    at::ScalarType type1 = at::kFloat,
    at::ScalarType type2 = at::kFloat);
// Runs a serialized TensorRT engine (`eng`) on the given inputs and returns the outputs.
std::vector<at::Tensor> RunEngine(std::string& eng, std::vector<at::Tensor> inputs);
// Runs an arbitrary JIT graph and returns results
std::vector<at::Tensor> RunGraph(
    std::shared_ptr<torch::jit::Graph>& g,
    core::ir::StaticParams& named_params,
    std::vector<at::Tensor> inputs);
// Runs an arbitrary JIT graph by converting it to TensorRT and running
// inference and returns results
std::vector<at::Tensor> RunGraphEngine(
    std::shared_ptr<torch::jit::Graph>& g,
    core::ir::StaticParams& named_params,
    std::vector<at::Tensor> inputs,
    nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT);
// Runs an arbitrary JIT graph with dynamic input sizes by converting it to
// TensorRT and running inference and returns results
std::vector<at::Tensor> RunGraphEngineDynamic(
    std::shared_ptr<torch::jit::Graph>& g,
    core::ir::StaticParams& named_params,
    std::vector<at::Tensor> inputs,
    bool dynamic_batch = false);
// Run the forward method of a module and return results
torch::jit::IValue RunModuleForward(torch::jit::Module& mod, std::vector<torch::jit::IValue> inputs);
// Convert the forward module to a TRT engine and return results
std::vector<at::Tensor> RunModuleForwardAsEngine(torch::jit::Module& mod, std::vector<at::Tensor> inputs);
// Runs evaluatable graphs through the compiler evaluator library and returns results
std::vector<torch::jit::IValue> EvaluateGraph(const torch::jit::Block* b, std::vector<torch::jit::IValue> inputs);
// Runs evaluatable graphs through the JIT interpreter and returns results
std::vector<torch::jit::IValue> EvaluateGraphJIT(
    std::shared_ptr<torch::jit::Graph>& g,
    std::vector<torch::jit::IValue> inputs);
} // namespace util
} // namespace tests
} // namespace torch_tensorrt