怎么理解nms?
非极大值抑制(Non-Maximum Suppression, NMS),简单地说就是:给定一大堆 bbox 及其对应的得分,先按得分从高到低排序;对于相互重叠的两个 box,如果它们的交并比(IoU = 交集面积 / 并集面积)大于设定的 threshold,就抛弃其中得分较低的那个,
如此反复,直到所有的 box 都判定完为止。
/// Axis-aligned box described by two opposite corners.
/// NOTE(review): nms computes widths as x2 - x1 + 1 and heights as
/// y2 - y1 + 1, so it expects x1 <= x2 and y1 <= y2 — confirm that
/// producers of these boxes follow that ordering.
struct anchor_box
{
    float x1;  // first corner, x coordinate (the smaller x as used by nms)
    float y1;  // first corner, y coordinate (the smaller y as used by nms)
    float x2;  // opposite corner, x coordinate
    float y2;  // opposite corner, y coordinate
};
/// A single face detection: confidence score, bounding box and landmark points.
struct FaceDetectInfo
{
    float score;      // detection confidence; NOTE(review): presumably the key CompareBBox sorts on — confirm
    anchor_box rect;  // bounding box; the only geometry nms uses for IoU
    FacePts pts;      // facial landmark points; not used by nms
};
std::vector<FaceDetectInfo> RetinaFace::nms(std::vector<FaceDetectInfo>& bboxes, float threshold)
{
std::vector<FaceDetectInfo> bboxes_nms;
std::sort(bboxes.begin(), bboxes.end(), CompareBBox);
int32_t select_idx = 0;
int32_t num_bbox = static_cast<int32_t>(bboxes.size());
std::vector<int32_t> mask_merged(num_bbox, 0);
bool all_merged = false;
while (!all_merged) {
while (select_idx < num_bbox && mask_merged[select_idx] == 1)
select_idx++;
if (select_idx == num_bbox) {
all_merged = true;
continue;
}
bboxes_nms.push_back(bboxes[select_idx]);
mask_merged[select_idx] = 1;
anchor_box select_bbox = bboxes[select_idx].rect;
float area1 = static_cast<float>((select_bbox.x2 - select_bbox.x1 + 1) * (select_bbox.y2 - select_bbox.y1 + 1));
float x1 = static_cast<float>(select_bbox.x1);
float y1 = static_cast<float>(select_bbox.y1);
float x2 = static_cast<float>(select_bbox.x2);
float y2 = static_cast<float>(select_bbox.y2);
select_idx++;
for (int32_t i = select_idx; i < num_bbox; i++) {
if (mask_merged[i] == 1)
continue;
anchor_box& bbox_i = bboxes[i].rect;
float x = std::max<float>(x1, static_cast<float>(bbox_i.x1));
float y = std::max<float>(y1, static_cast<float>(bbox_i.y1));
float w = std::min<float>(x2, static_cast<float>(bbox_i.x2)) - x + 1; //<- float 型不加1
float h = std::min<float>(y2, static_cast<float>(bbox_i.y2)) - y + 1;
if (w <= 0 || h <= 0)
continue;
float area2 = static_cast<float>((bbox_i.x2 - bbox_i.x1 + 1) * (bbox_i.y2 - bbox_i.y1 + 1));
float area_intersect = w * h;
if (static_cast<float>(area_intersect) / (area1 + area2 - area_intersect) > threshold) {
mask_merged[i] = 1;
}
}
}
return bboxes_nms;
}
这段代码来自 RetinaFace(mnet 骨干网络)的某个 TensorRT 实现,具体出处我忘了。我觉得这段代码的可优化空间很大。
代码思路很简单,下面用一段伪代码来描述:
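input: vector<bbox> bboxs (sorted by score in descending order), float threshold; all_merged = false
while (!all_merged):
    select the first 'it' in bboxs that has not been merged; if all bboxs have been merged, set all_merged = true, then break the loop.
    append 'it' to the output and mark 'it' as merged.
    for each 'other' after 'it' in bboxs:
        if 'other' is not merged and IOU('it', 'other') > threshold:
            mark 'other' as merged
        else continue
return output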