@@ -5,21 +5,19 @@ namespace torch_tensorrt {
5
5
namespace tests {
6
6
namespace util {
7
7
8
- bool checkRtol (const at::Tensor& diff, const std::vector<at::Tensor> inputs, float threshold) {
9
- double maxValue = 0.0 ;
10
- for (auto & tensor : inputs) {
11
- maxValue = fmax (tensor.abs ().max ().item <float >(), maxValue);
12
- }
13
- std::cout << " Max Difference: " << diff.abs ().max ().item <float >() << std::endl;
14
- std::cout << " Acceptable Threshold: " << threshold << std::endl;
15
- return diff.abs ().max ().item <float >() <= threshold * maxValue;
16
- }
17
8
18
- bool almostEqual (const at::Tensor& a, const at::Tensor& b, float threshold) {
9
+ bool almostEqual (const at::Tensor& a, const at::Tensor& b, float threshold, float atol= 1e-8 , float rtol= 1e-5 ) {
19
10
LOG_GRAPH (a << std::endl << b << std::endl);
20
11
auto a_float = a.toType (at::kFloat );
21
12
auto b_float = b.toType (at::kFloat );
22
- return checkRtol (a_float - b_float, {a_float, b_float}, threshold);
13
+
14
+ auto diff = a_float - b_float;
15
+ auto result = diff.abs ().max ().item <float >() - (atol + rtol * b.abs ().max ().item <float >());
16
+
17
+ std::cout << " Max Difference: " << result << std::endl;
18
+ std::cout << " Acceptable Threshold: " << threshold << std::endl;
19
+
20
+ return result <= threshold;
23
21
}
24
22
25
23
bool exactlyEqual (const at::Tensor& a, const at::Tensor& b) {
0 commit comments