Skip to content

Fix integer overflow in cuda NonMaxSuppression implementation #2540

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ Status NonMaxSuppression::ComputeInternal(OpKernelContext* ctx) const {
std::vector<std::tuple<IAllocatorUniquePtr<void>, int>> all_selected_indices;
int total_num_saved_outputs = 0;

// safe downcast max_output_boxes_per_class to int as cub::DeviceSelect::Flagged() does not support int64_t
int int_max_output_boxes_per_class = max_output_boxes_per_class > std::numeric_limits<int>::max()
? std::numeric_limits<int>::max()
: static_cast<int>(max_output_boxes_per_class);

for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) {
for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) {
IAllocatorUniquePtr<void> d_selected_indices{};
Expand All @@ -66,7 +71,7 @@ Status NonMaxSuppression::ComputeInternal(OpKernelContext* ctx) const {
GetCenterPointBox(),
batch_index,
class_index,
max_output_boxes_per_class,
int_max_output_boxes_per_class,
iou_threshold,
score_threshold,
d_selected_indices,
Expand Down Expand Up @@ -130,4 +135,4 @@ Status NonMaxSuppression::ComputeInternal(OpKernelContext* ctx) const {
}

} // namespace cuda
}; // namespace onnxruntime
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,6 @@ constexpr int kNmsBlockDim = 16;
constexpr int kNmsBlockDimMax = 128;
constexpr int kNmsChunkSize = 2000;

// Exchanges the values held by `a` and `b`, using one temporary copy of `a`.
template <typename T>
__device__ inline void Swap(T& a, T& b) {
  T previous_a(a);
  a = b;
  b = previous_a;
}

// Check whether two boxes have an IoU greater than threshold.
template <typename T>
__device__ inline bool OverThreshold(const Box* a, const Box* b,
Expand All @@ -88,10 +81,6 @@ __device__ inline bool OverThreshold(const Box* a, const Box* b,
return aa >= bt;
}

// Reorders the corner coordinates of `box` in place: x1/x2 are exchanged when
// x1 > x2, and y1/y2 are exchanged when y1 > y2. Each axis is handled
// independently of the other.
__device__ inline void Flipped(Box& box) {
  if (box.x1 > box.x2) {
    Swap(box.x1, box.x2);
  }
  if (box.y1 > box.y2) {
    Swap(box.y1, box.y2);
  }
}
template <typename T>
__device__ inline bool CheckBit(T* bit_mask, int bit) {
constexpr int kShiftLen = NumBits(8 * sizeof(T)) - 1;
Expand All @@ -104,7 +93,7 @@ __device__ inline bool CheckBit(T* bit_mask, int bit) {
// generated by NMSKernel. Abort early if max_boxes boxes are selected. Bitmask
// is num_boxes*bit_mask_len bits indicating whether to keep or remove a box.
__global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
const int num_boxes, const int64_t max_boxes,
const int num_boxes, const int max_boxes,
char* result_mask) {
extern __shared__ int local[];

Expand Down Expand Up @@ -247,7 +236,7 @@ Status NmsGpu(std::function<IAllocatorUniquePtr<void>(size_t)> allocator,
const float iou_threshold,
int* d_selected_indices,
int* h_nkeep,
const int64_t max_boxes) {
const int max_boxes) {
// Making sure we respect the __align(16)__
// we promised to the compiler.
auto iptr = reinterpret_cast<std::uintptr_t>(d_sorted_boxes_float_ptr);
Expand Down Expand Up @@ -337,7 +326,7 @@ Status NonMaxSuppressionImpl(
const int64_t center_point_box,
int64_t batch_index,
int64_t class_index,
int64_t max_output_boxes_per_class,
int max_output_boxes_per_class,
float iou_threshold,
float score_threshold,
IAllocatorUniquePtr<void>& selected_indices,
Expand Down Expand Up @@ -427,7 +416,7 @@ Status NonMaxSuppressionImpl(
CUDA_RETURN_IF_ERROR(cudaGetLastError());

// STEP 4. map back to sorted indices
*h_number_selected = std::min(*h_number_selected, (int)max_output_boxes_per_class);
*h_number_selected = std::min(*h_number_selected, max_output_boxes_per_class);
int num_to_keep = *h_number_selected;
if (num_to_keep > 0) {
IAllocatorUniquePtr<void> d_output_indices_ptr{allocator(num_to_keep * sizeof(int))};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Status NonMaxSuppressionImpl(
const int64_t center_point_box,
int64_t batch_index,
int64_t class_index,
int64_t max_output_boxes_per_class,
int max_output_boxes_per_class,
float iou_threshold,
float score_threshold,
IAllocatorUniquePtr<void>& selected_indices,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,25 @@ TEST(NonMaxSuppressionOpTest, ZeroMaxOutputPerClass) {
test.Run();
}

// Runs NonMaxSuppression (opset 10) with max_output_boxes_per_class set to
// 2^63 - 1, the largest int64_t value, which cannot be represented in a
// 32-bit int. The expected output must be unaffected by that huge limit.
TEST(NonMaxSuppressionOpTest, BigIntMaxOutputBoxesPerClass) {
OpTester test("NonMaxSuppression", 10, kOnnxDomain);
// Shape {1, 6, 4}: presumably 1 batch x 6 boxes x 4 coordinates per box
// (NOTE(review): confirm axis meaning against the ONNX operator spec).
test.AddInput<float>("boxes", {1, 6, 4},
{0.0f, 0.0f, 1.0f, 1.0f,
0.0f, 0.1f, 1.0f, 1.1f,
0.0f, -0.1f, 1.0f, 0.9f,
0.0f, 10.0f, 1.0f, 11.0f,
0.0f, 10.1f, 1.0f, 11.1f,
0.0f, 100.0f, 1.0f, 101.0f});
// Six scores; only the last one (0.3) is below the score_threshold of 0.4
// supplied below.
test.AddInput<float>("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f});
// 9223372036854775807 == 2^63 - 1 (maximum int64_t).
test.AddInput<int64_t>("max_output_boxes_per_class", {}, {9223372036854775807L});
test.AddInput<float>("iou_threshold", {}, {0.5f});
test.AddInput<float>("score_threshold", {}, {0.4f});
// Two rows are expected, whose last entries are 3 and 0 -- the positions of
// the two highest scores (0.95 and 0.9) in the scores input.
// NOTE(review): each row is presumably (batch_index, class_index, box_index)
// as in the ONNX NonMaxSuppression spec -- confirm.
test.AddOutput<int64_t>("selected_indices", {2, 3},
{0L, 0L, 3L,
0L, 0L, 0L});
test.Run();
}

TEST(NonMaxSuppressionOpTest, WithIOUThresholdOpset11) {
OpTester test("NonMaxSuppression", 11, kOnnxDomain);
test.AddInput<float>("boxes", {1, 6, 4},
Expand Down