Skip to content

Commit 0748f1d

Browse files
authored
Merge pull request #165 from stanstarks/master
add multi-label nms for FCOS post process
2 parents ebdd88b + 2cd10da commit 0748f1d

File tree

8 files changed

+200
-28
lines changed

8 files changed

+200
-28
lines changed

fcos_core/csrc/cuda/ml_nms.cu

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2+
#include <ATen/ATen.h>
3+
#include <ATen/cuda/CUDAContext.h>
4+
5+
#include <THC/THC.h>
6+
#include <THC/THCDeviceUtils.cuh>
7+
8+
#include <vector>
9+
#include <iostream>
10+
11+
int const threadsPerBlock = sizeof(unsigned long long) * 8;
12+
13+
// Label-aware intersection-over-union of two boxes.
//
// Each box is read as six consecutive floats: elements 0..3 are the corner
// coordinates (x1, y1, x2, y2) and element 5 is the label; element 4 is not
// read here. Widths and heights use the inclusive "+ 1" pixel convention.
// Boxes whose labels differ are reported as having zero overlap, so they can
// never suppress one another.
__device__ inline float devIoU(float const * const a, float const * const b) {
  const bool same_label = (a[5] == b[5]);
  if (!same_label) {
    return 0.0f;
  }

  // Corners of the intersection rectangle.
  const float x_lo = max(a[0], b[0]);
  const float x_hi = min(a[2], b[2]);
  const float y_lo = max(a[1], b[1]);
  const float y_hi = min(a[3], b[3]);

  // Clamp at zero so disjoint boxes yield an empty intersection.
  const float overlap_w = max(x_hi - x_lo + 1, 0.f);
  const float overlap_h = max(y_hi - y_lo + 1, 0.f);
  const float inter = overlap_w * overlap_h;

  const float area_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
  const float area_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
  return inter / (area_a + area_b - inter);
}
25+
26+
// Builds the pairwise suppression bitmask used by multi-label NMS.
//
// Launch layout: a 2-D grid in which blockIdx.y selects a tile of "row" boxes
// and blockIdx.x a tile of "column" boxes, each tile holding up to
// threadsPerBlock boxes; blocks are 1-D with threadsPerBlock threads.
// Static shared memory: threadsPerBlock * 6 floats.
//
// n_boxes:            number of boxes in dev_boxes.
// nms_overlap_thresh: bit is set when devIoU(row, column) exceeds this value.
// dev_boxes:          n_boxes * 6 floats, six per box.
// dev_mask:           n_boxes * ceil(n_boxes / threadsPerBlock) words; word
//                     [row * col_blocks + col_tile] has bit i set when the
//                     row box overlaps box i of that column tile.
__global__ void ml_nms_kernel(const int n_boxes, const float nms_overlap_thresh,
                              const float *dev_boxes, unsigned long long *dev_mask) {
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // The last tile along each axis may be only partially filled.
  const int row_size =
      min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
  const int col_size =
      min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

  // Stage the column tile in shared memory: one box (six floats) per thread.
  __shared__ float block_boxes[threadsPerBlock * 6];
  if (threadIdx.x < col_size) {
    const float *src =
        dev_boxes + (threadsPerBlock * col_start + threadIdx.x) * 6;
    float *dst = block_boxes + threadIdx.x * 6;
    for (int k = 0; k < 6; ++k) {
      dst[k] = src[k];
    }
  }
  __syncthreads();

  // No barrier follows, so threads without a row box can leave here.
  if (threadIdx.x >= row_size) {
    return;
  }

  const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
  const float *cur_box = dev_boxes + cur_box_idx * 6;

  // On the diagonal tile only compare against later boxes, so a box is never
  // tested against itself or against boxes that precede it in the tile.
  const int first =
      (row_start == col_start) ? static_cast<int>(threadIdx.x) + 1 : 0;

  unsigned long long bits = 0;
  for (int c = first; c < col_size; ++c) {
    if (devIoU(cur_box, block_boxes + c * 6) > nms_overlap_thresh) {
      bits |= 1ULL << c;
    }
  }

  const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
  dev_mask[cur_box_idx * col_blocks + col_start] = bits;
}
73+
74+
// boxes is a N x 6 tensor
75+
// Multi-label non-maximum suppression on the GPU.
//
// boxes:              N x 6 float CUDA tensor; column 4 is the score used for
//                     ordering and column 5 the label compared by devIoU.
// nms_overlap_thresh: a box is dropped when its devIoU with an already kept,
//                     higher-scoring box exceeds this value.
//
// Returns a 1-D int64 tensor, on the same device as `boxes`, holding the row
// indices of the kept boxes in ascending order. An empty input yields an
// empty index tensor without launching any kernel.
at::Tensor ml_nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
  using scalar_t = float;
  static_assert(sizeof(unsigned long long) == sizeof(int64_t),
                "mask words are stored in an int64 tensor");

  AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor");
  // The kernel indexes six floats per row through a raw pointer, so a wrong
  // shape would read out of bounds instead of failing cleanly.
  AT_ASSERTM(boxes.dim() == 2 && boxes.size(1) == 6,
             "boxes must be a N x 6 tensor");

  const int boxes_num = boxes.size(0);
  if (boxes_num == 0) {
    // A zero-sized grid is an invalid launch configuration; nothing to keep.
    return at::empty({0}, boxes.options().dtype(at::kLong));
  }

  // Process boxes from highest to lowest score.
  auto scores = boxes.select(1, 4);
  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
  auto boxes_sorted = boxes.index_select(0, order_t);

  const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
  const int64_t mask_len = static_cast<int64_t>(boxes_num) * col_blocks;

  scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();

  // Hold the mask in a tensor so it is released even if a check below throws.
  at::Tensor mask = at::empty({mask_len}, boxes.options().dtype(at::kLong));
  unsigned long long* mask_dev =
      reinterpret_cast<unsigned long long*>(mask.data<int64_t>());

  // Run on the current ATen stream so the kernel is ordered after the sort
  // and gather above, which were enqueued on that same stream.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 blocks(col_blocks, col_blocks);
  dim3 threads(threadsPerBlock);
  ml_nms_kernel<<<blocks, threads, 0, stream>>>(boxes_num,
                                                nms_overlap_thresh,
                                                boxes_dev,
                                                mask_dev);
  // Kernel launches do not report errors themselves.
  THCudaCheck(cudaGetLastError());

  std::vector<unsigned long long> mask_host(mask_len);
  THCudaCheck(cudaMemcpyAsync(mask_host.data(),
                              mask_dev,
                              sizeof(unsigned long long) * mask_len,
                              cudaMemcpyDeviceToHost,
                              stream));
  THCudaCheck(cudaStreamSynchronize(stream));

  // remv[b] bit k is set once box (b * threadsPerBlock + k) is suppressed.
  std::vector<unsigned long long> remv(col_blocks, 0ULL);

  at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
  int64_t* keep_out = keep.data<int64_t>();

  int num_to_keep = 0;
  for (int i = 0; i < boxes_num; i++) {
    const int nblock = i / threadsPerBlock;
    const int inblock = i % threadsPerBlock;

    if (!(remv[nblock] & (1ULL << inblock))) {
      keep_out[num_to_keep++] = i;
      // Fold this box's row into the running suppression set. Only tiles
      // at or after its own can contain lower-scoring boxes.
      const unsigned long long *p =
          mask_host.data() + static_cast<size_t>(i) * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        remv[j] |= p[j];
      }
    }
  }

  // Map positions in the sorted order back to original row indices and
  // return them in ascending order.
  // TODO improve this part
  return std::get<0>(order_t.index({
                       keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
                         order_t.device(), keep.scalar_type())
                     }).sort(0, false));
}

fcos_core/csrc/cuda/vision.h

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
5656
const int width);
5757

5858
at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);
59+
at::Tensor ml_nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);
5960

6061
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
6162
at::Tensor offset, at::Tensor output,

fcos_core/csrc/ml_nms.h

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2+
#pragma once
3+
#include "cpu/vision.h"
4+
5+
#ifdef WITH_CUDA
6+
#include "cuda/vision.h"
7+
#endif
8+
9+
10+
// Dispatches multi-label NMS to the CUDA implementation.
//
// dets, scores and labels are concatenated column-wise (scores and labels
// gain a trailing dimension first) into a single tensor that is handed to
// ml_nms_cuda together with `threshold`.
// Raises an error for CPU inputs (no CPU implementation) and for CUDA inputs
// when the extension was built without WITH_CUDA.
// An empty `dets` returns an empty int64 CPU tensor.
at::Tensor ml_nms(const at::Tensor& dets,
                  const at::Tensor& scores,
                  const at::Tensor& labels,
                  const float threshold) {
  const bool on_gpu = dets.type().is_cuda();
  if (!on_gpu) {
    AT_ERROR("CPU version not implemented");
  }

#ifdef WITH_CUDA
  if (dets.numel() == 0) {
    return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
  }
  // Pack one row per box: dets columns, then score, then label.
  auto packed = at::cat({dets, scores.unsqueeze(1), labels.unsqueeze(1)}, 1);
  return ml_nms_cuda(packed, threshold);
#else
  AT_ERROR("Not compiled with GPU support");
#endif
}

fcos_core/csrc/vision.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
22
#include "nms.h"
3+
#include "ml_nms.h"
34
#include "ROIAlign.h"
45
#include "ROIPool.h"
56
#include "SigmoidFocalLoss.h"
@@ -8,6 +9,7 @@
89

910
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1011
m.def("nms", &nms, "non-maximum suppression");
12+
m.def("ml_nms", &ml_nms, "multi-label non-maximum suppression");
1113
m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward");
1214
m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward");
1315
m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward");

fcos_core/layers/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .misc import ConvTranspose2d
88
from .misc import BatchNorm2d
99
from .misc import interpolate
10-
from .nms import nms
10+
from .nms import nms, ml_nms
1111
from .roi_align import ROIAlign
1212
from .roi_align import roi_align
1313
from .roi_pool import ROIPool
@@ -26,6 +26,7 @@
2626

2727
__all__ = [
2828
"nms",
29+
"ml_nms",
2930
"roi_align",
3031
"ROIAlign",
3132
"roi_pool",

fcos_core/layers/nms.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
from fcos_core import _C
44

55
nms = _C.nms
6+
ml_nms = _C.ml_nms
67
# nms.__doc__ = """
78
# This function performs Non-maximum suppression"""

fcos_core/modeling/rpn/fcos/inference.py

+3-27
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from fcos_core.modeling.utils import cat
88
from fcos_core.structures.bounding_box import BoxList
99
from fcos_core.structures.boxlist_ops import cat_boxlist
10-
from fcos_core.structures.boxlist_ops import boxlist_nms
10+
from fcos_core.structures.boxlist_ops import boxlist_ml_nms
1111
from fcos_core.structures.boxlist_ops import remove_small_boxes
1212

1313

@@ -146,32 +146,8 @@ def select_over_all_levels(self, boxlists):
146146
num_images = len(boxlists)
147147
results = []
148148
for i in range(num_images):
149-
scores = boxlists[i].get_field("scores")
150-
labels = boxlists[i].get_field("labels")
151-
boxes = boxlists[i].bbox
152-
boxlist = boxlists[i]
153-
result = []
154-
# skip the background
155-
for j in range(1, self.num_classes):
156-
inds = (labels == j).nonzero().view(-1)
157-
158-
scores_j = scores[inds]
159-
boxes_j = boxes[inds, :].view(-1, 4)
160-
boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
161-
boxlist_for_class.add_field("scores", scores_j)
162-
boxlist_for_class = boxlist_nms(
163-
boxlist_for_class, self.nms_thresh,
164-
score_field="scores"
165-
)
166-
num_labels = len(boxlist_for_class)
167-
boxlist_for_class.add_field(
168-
"labels", torch.full((num_labels,), j,
169-
dtype=torch.int64,
170-
device=scores.device)
171-
)
172-
result.append(boxlist_for_class)
173-
174-
result = cat_boxlist(result)
149+
# multiclass nms
150+
result = boxlist_ml_nms(boxlists[i], self.nms_thresh)
175151
number_of_detections = len(result)
176152

177153
# Limit to max_per_image detections **over all classes**

fcos_core/structures/boxlist_ops.py

+28
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .bounding_box import BoxList
55

66
from fcos_core.layers import nms as _box_nms
7+
from fcos_core.layers import ml_nms as _box_ml_nms
78

89

910
def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
@@ -31,6 +32,33 @@ def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
3132
return boxlist.convert(mode)
3233

3334

35+
def boxlist_ml_nms(boxlist, nms_thresh, max_proposals=-1,
                   score_field="scores", label_field="labels"):
    """
    Performs multi-label non-maximum suppression on a boxlist: a box can only
    be suppressed by a higher-scoring box that carries the same label.
    Scores and labels are read from the boxlist fields named by score_field
    and label_field.

    Arguments:
        boxlist(BoxList)
        nms_thresh (float): overlap threshold; if <= 0 the boxlist is
            returned unchanged
        max_proposals (int): if > 0, then only the max_proposals
            highest-scoring boxes are kept after non-maximum suppression
        score_field (str)
        label_field (str)

    Returns:
        BoxList: the surviving boxes, in the input boxlist's original mode
    """
    if nms_thresh <= 0:
        return boxlist
    mode = boxlist.mode
    boxlist = boxlist.convert("xyxy")
    boxes = boxlist.bbox
    scores = boxlist.get_field(score_field)
    labels = boxlist.get_field(label_field)
    # Labels are cast to float so they can be packed with boxes and scores.
    keep = _box_ml_nms(boxes, scores, labels.float(), nms_thresh)
    if 0 < max_proposals < keep.numel():
        # `keep` is ordered by box index, not by score, so slicing its head
        # would keep arbitrary survivors. Pick the best-scoring ones, then
        # restore ascending index order so the output ordering is unchanged.
        top = scores[keep].topk(max_proposals)[1]
        keep = keep[top.sort()[0]]
    boxlist = boxlist[keep]
    return boxlist.convert(mode)
60+
61+
3462
def remove_small_boxes(boxlist, min_size):
3563
"""
3664
Only keep boxes with both sides >= min_size

0 commit comments

Comments
 (0)