Skip to content

Commit 0b0ba8d

Browse files
committed
fix: Considering rtol and atol in threshold comparison for floating point numbers
Signed-off-by: Anurag Dixit <[email protected]>
1 parent ef62f6b commit 0b0ba8d

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

tests/util/util.cpp

+9-11
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,19 @@ namespace torch_tensorrt {
55
namespace tests {
66
namespace util {
77

8-
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs, float threshold) {
9-
double maxValue = 0.0;
10-
for (auto& tensor : inputs) {
11-
maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
12-
}
13-
std::cout << "Max Difference: " << diff.abs().max().item<float>() << std::endl;
14-
std::cout << "Acceptable Threshold: " << threshold << std::endl;
15-
return diff.abs().max().item<float>() <= threshold * maxValue;
16-
}
178

18-
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold) {
9+
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold, float atol=1e-8, float rtol=1e-5) {
1910
LOG_GRAPH(a << std::endl << b << std::endl);
2011
auto a_float = a.toType(at::kFloat);
2112
auto b_float = b.toType(at::kFloat);
22-
return checkRtol(a_float - b_float, {a_float, b_float}, threshold);
13+
14+
auto diff = a_float - b_float;
15+
auto result = diff.abs().max().item<float>() - (atol + rtol * b.abs().max().item<float>());
16+
17+
std::cout << "Max Difference: " << result << std::endl;
18+
std::cout << "Acceptable Threshold: " << threshold << std::endl;
19+
20+
return result <= threshold;
2321
}
2422

2523
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) {

tests/util/util.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace torch_tensorrt {
1111
namespace tests {
1212
namespace util {
1313

14-
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold);
14+
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold, float atol=1e-8, float rtol=1e-5);
1515

1616
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
1717

0 commit comments

Comments
 (0)