怎么理解nms?
非極大值抑制(NMS),簡單地說就是:給出一大堆 bbox 和相應的得分,對於其中區域重合的 box,如果兩個 box 的重合程度(IoU,即交集面積除以並集面積)大於設定的 threshold,就拋棄得分較低的那個,
直到所有的 box 都判定完了。
// Axis-aligned rectangle described by two corner points.
// RetinaFace::nms in this file measures its extent as
// (x2 - x1 + 1) by (y2 - y1 + 1), so (x1, y1) is treated as the minimum
// corner and (x2, y2) as the maximum corner.
// NOTE(review): presumably top-left / bottom-right in image pixels — confirm
// against the code that fills these fields.
struct anchor_box
{
    float x1, y1;  // first corner (smaller x / y)
    float x2, y2;  // opposite corner (larger x / y)
};
// One detection candidate: a score, its bounding box and its landmark points.
// RetinaFace::nms copies whole FaceDetectInfo values into its result but only
// reads `rect` when measuring overlap.
struct FaceDetectInfo
{
float score; // score attached to this box (presumably detection confidence — confirm)
anchor_box rect; // bounding box; the only field nms inspects for overlap
FacePts pts; // landmark points; they play no role in nms
};
std::vector<FaceDetectInfo> RetinaFace::nms(std::vector<FaceDetectInfo>& bboxes, float threshold)
{
std::vector<FaceDetectInfo> bboxes_nms;
std::sort(bboxes.begin(), bboxes.end(), CompareBBox);
int32_t select_idx = 0;
int32_t num_bbox = static_cast<int32_t>(bboxes.size());
std::vector<int32_t> mask_merged(num_bbox, 0);
bool all_merged = false;
while (!all_merged) {
while (select_idx < num_bbox && mask_merged[select_idx] == 1)
select_idx++;
if (select_idx == num_bbox) {
all_merged = true;
continue;
}
bboxes_nms.push_back(bboxes[select_idx]);
mask_merged[select_idx] = 1;
anchor_box select_bbox = bboxes[select_idx].rect;
float area1 = static_cast<float>((select_bbox.x2 - select_bbox.x1 + 1) * (select_bbox.y2 - select_bbox.y1 + 1));
float x1 = static_cast<float>(select_bbox.x1);
float y1 = static_cast<float>(select_bbox.y1);
float x2 = static_cast<float>(select_bbox.x2);
float y2 = static_cast<float>(select_bbox.y2);
select_idx++;
for (int32_t i = select_idx; i < num_bbox; i++) {
if (mask_merged[i] == 1)
continue;
anchor_box& bbox_i = bboxes[i].rect;
float x = std::max<float>(x1, static_cast<float>(bbox_i.x1));
float y = std::max<float>(y1, static_cast<float>(bbox_i.y1));
float w = std::min<float>(x2, static_cast<float>(bbox_i.x2)) - x + 1; //<- float 型不加1
float h = std::min<float>(y2, static_cast<float>(bbox_i.y2)) - y + 1;
if (w <= 0 || h <= 0)
continue;
float area2 = static_cast<float>((bbox_i.x2 - bbox_i.x1 + 1) * (bbox_i.y2 - bbox_i.y1 + 1));
float area_intersect = w * h;
if (static_cast<float>(area_intersect) / (area1 + area2 - area_intersect) > threshold) {
mask_merged[i] = 1;
}
}
}
return bboxes_nms;
}
這段代碼來自 RetinaFace(mnet)的一個 TensorRT 實現,具體地址我忘了。我覺得這段代碼的可優化空間很大。
代碼思路很簡單。寫段偽代碼描述下
input: vector<bbox> bboxs, float threshold, bool all_merged = false
while (!all_merged):
    select the first 'it' in bboxs that is not merged yet; if all bboxs have been merged, set all_merged to true, then break the loop.
    keep 'it' in the result and mark 'it' as merged
    for other in bboxs after 'it':
        if 'other' is not merged and IOU('it', 'other') > threshold:
            set 'other' merged
        else continue