Pytorch-Faster-RCNN 中的 MAP 實現 (解析imdb.py 和 pascal_voc.py)



mAP(mean Average Precision)是衡量object detection算法的重要指標,然而我一直沒有仔細閱讀相關代碼,今天就好好看一下:

1. 測試test過程是由FRCN/tools/test_net.py中調用的test_net()完成 #from model.test import test_net

test_net()定義在FRCN/lib/model/test.py (193-194行):調用了imdb.evaluate_detections

print('Evaluating detections')
imdb.evaluate_detections(all_boxes, output_dir)

imdb是在FRCN/tools/test_net.py(84行)中創建,再傳入test_net()的:

imdb = get_imdb(args.imdb_name)

get_imdb來自from datasets.factory import get_imdb。為了了解imdb是如何定義的,我們來看FRCN/lib/datasets/factory.py:

 1 """Factory method for easily getting imdbs by name."""
 2 from __future__ import absolute_import
 3 from __future__ import division
 4 from __future__ import print_function
 5 
 6 __sets = {}
 7 from datasets.pascal_voc import pascal_voc
 8 
 9 import numpy as np
10 
# Set up voc_<year>_<split> (difficult objects excluded) and
# voc_<year>_<split>_diff (difficult objects kept).
# Every registry value is a zero-argument factory. `split` and `year` are
# bound as lambda default arguments so each factory remembers the values of
# its own iteration rather than the loop's final ones.
for suffix, keep_diff in (('', False), ('_diff', True)):
  for year in ['2007', '2012']:
    for split in ['train', 'val', 'trainval', 'test']:
      name = 'voc_{}_{}{}'.format(year, split, suffix)
      if keep_diff:
        __sets[name] = (lambda split=split, year=year:
                        pascal_voc(split, year, use_diff=True))
      else:
        __sets[name] = (lambda split=split, year=year:
                        pascal_voc(split, year))
21 
def get_imdb(name):
  """Build and return the imdb (image database) registered as `name`.

  The registry stores zero-argument factories, so every call constructs a
  fresh dataset object.

  Raises:
    KeyError: if no dataset is registered under `name`.
  """
  factory = __sets.get(name)
  if factory is None:
    raise KeyError('Unknown dataset: {}'.format(name))
  return factory()
27 
def list_imdbs():
  """Return a new list with the names of all registered imdbs."""
  return list(__sets)

coco數據集的註冊方式與pascal_voc相同(上面列出的factory.py中沒有包含這部分)。可以看到,get_imdb(args.imdb_name)返回的就是pascal_voc(split, year)這樣一個對象;名稱帶_diff後綴時則是pascal_voc(split, year, use_diff=True)。

 

2. 來到pascal_voc.py :

  1 # --------------------------------------------------------
  2 # Fast R-CNN
  3 # Copyright (c) 2015 Microsoft
  4 # Licensed under The MIT License [see LICENSE for details]
  5 # Written by Ross Girshick and Xinlei Chen
  6 # --------------------------------------------------------
  7 from __future__ import absolute_import
  8 from __future__ import division
  9 from __future__ import print_function
 10 
 11 import os
 12 from datasets.imdb import imdb
 13 import datasets.ds_utils as ds_utils
 14 import xml.etree.ElementTree as ET
 15 import numpy as np
 16 import scipy.sparse
 17 import scipy.io as sio
 18 import pickle
 19 import subprocess
 20 import uuid
 21 from .voc_eval import voc_eval
 22 from model.config import cfg
 23 
 24 
 25 class pascal_voc(imdb):
 26   def __init__(self, image_set, year, use_diff=False):
 27     name = 'voc_' + year + '_' + image_set
 28     if use_diff:
 29       name += '_diff'
 30     imdb.__init__(self, name)
 31     self._year = year
 32     self._image_set = image_set
 33     self._devkit_path = self._get_default_path()
 34     self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
 35     self._classes = ('__background__',  # always index 0
 36                      'title', 'xlabel',  'ylabel')
 37                  ####    'text', 'ylabel')
 38                  #    'aeroplane', 'bicycle', 'bird', 'boat',
 39                  #    'bottle', 'bus', 'car', 'cat', 'chair',
 40                  #    'cow', 'diningtable', 'dog', 'horse',
 41                  #    'motorbike', 'person', 'pottedplant',
 42                  #    'sheep', 'sofa', 'train', 'tvmonitor')
 43     self._class_to_ind = dict(list(zip(self.classes, list(range(self.num_classes)))))
 44     self._image_ext = '.jpg'
 45     self._image_index = self._load_image_set_index()
 46     # Default to roidb handler
 47     self._roidb_handler = self.gt_roidb
 48     self._salt = str(uuid.uuid4())
 49     self._comp_id = 'comp4'
 50 
 51     # PASCAL specific config options
 52     self.config = {'cleanup': True,
 53                    'use_salt': True,
 54                    'use_diff': use_diff,
 55                    'matlab_eval': False,
 56                    'rpn_file': None}
 57 
 58     assert os.path.exists(self._devkit_path), \
 59       'VOCdevkit path does not exist: {}'.format(self._devkit_path)
 60     assert os.path.exists(self._data_path), \
 61       'Path does not exist: {}'.format(self._data_path)
 62 
 63   def image_path_at(self, i):
 64     """
 65     Return the absolute path to image i in the image sequence.
 66     """
 67     return self.image_path_from_index(self._image_index[i])
 68 
 69   def image_path_from_index(self, index):
 70     """
 71     Construct an image path from the image's "index" identifier.
 72     """
 73     image_path = os.path.join(self._data_path, 'JPEGImages',
 74                               index + self._image_ext)
 75     assert os.path.exists(image_path), \
 76       'Path does not exist: {}'.format(image_path)
 77     return image_path
 78 
 79   def _load_image_set_index(self):
 80     """
 81     Load the indexes listed in this dataset's image set file.
 82     """
 83     # Example path to image set file:
 84     # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
 85     image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
 86                                   self._image_set + '.txt')
 87     assert os.path.exists(image_set_file), \
 88       'Path does not exist: {}'.format(image_set_file)
 89     with open(image_set_file) as f:
 90       image_index = [x.strip() for x in f.readlines()]
 91     return image_index
 92 
 93   def _get_default_path(self):
 94     """
 95     Return the default path where PASCAL VOC is expected to be installed.
 96     """
 97     return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year)
 98 
 99   def gt_roidb(self):
100     """
101     Return the database of ground-truth regions of interest.
102 
103     This function loads/saves from/to a cache file to speed up future calls.
104     """
105     cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
106     if os.path.exists(cache_file):
107       with open(cache_file, 'rb') as fid:
108         try:
109           roidb = pickle.load(fid)
110         except:
111           roidb = pickle.load(fid, encoding='bytes')
112       print('{} gt roidb loaded from {}'.format(self.name, cache_file))
113       return roidb
114 
115     gt_roidb = [self._load_pascal_annotation(index)
116                 for index in self.image_index]
117     with open(cache_file, 'wb') as fid:
118       pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
119     print('wrote gt roidb to {}'.format(cache_file))
120 
121     return gt_roidb
122 
123   def rpn_roidb(self):
124     if int(self._year) == 2007 or self._image_set != 'test':
125       gt_roidb = self.gt_roidb()
126       rpn_roidb = self._load_rpn_roidb(gt_roidb)
127       roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
128     else:
129       roidb = self._load_rpn_roidb(None)
130 
131     return roidb
132 
133   def _load_rpn_roidb(self, gt_roidb):
134     filename = self.config['rpn_file']
135     print('loading {}'.format(filename))
136     assert os.path.exists(filename), \
137       'rpn data not found at: {}'.format(filename)
138     with open(filename, 'rb') as f:
139       box_list = pickle.load(f)
140     return self.create_roidb_from_box_list(box_list, gt_roidb)
141 
142   def _load_pascal_annotation(self, index):
143     """
144     Load image and bounding boxes info from XML file in the PASCAL VOC
145     format.
146     """
147     filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
148     tree = ET.parse(filename)
149     objs = tree.findall('object')
150     if not self.config['use_diff']:
151       # Exclude the samples labeled as difficult
152       non_diff_objs = [
153         obj for obj in objs if int(obj.find('difficult').text) == 0]
154       # if len(non_diff_objs) != len(objs):
155       #     print 'Removed {} difficult objects'.format(
156       #         len(objs) - len(non_diff_objs))
157       objs = non_diff_objs
158     num_objs = len(objs)
159 
160     boxes = np.zeros((num_objs, 4), dtype=np.uint16)
161     gt_classes = np.zeros((num_objs), dtype=np.int32)
162     overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
163     # "Seg" area for pascal is just the box area
164     seg_areas = np.zeros((num_objs), dtype=np.float32)
165 
166     # Load object bounding boxes into a data frame.
167     for ix, obj in enumerate(objs):
168       bbox = obj.find('bndbox')
169       # Make pixel indexes 0-based
170       x1 = float(bbox.find('xmin').text) - 1
171       y1 = float(bbox.find('ymin').text) - 1
172       x2 = float(bbox.find('xmax').text) - 1
173       y2 = float(bbox.find('ymax').text) - 1
174       cls = self._class_to_ind[obj.find('name').text.lower().strip()]
175       boxes[ix, :] = [x1, y1, x2, y2]
176       gt_classes[ix] = cls
177       overlaps[ix, cls] = 1.0
178       seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
179 
180     overlaps = scipy.sparse.csr_matrix(overlaps)
181 
182     return {'boxes': boxes,
183             'gt_classes': gt_classes,
184             'gt_overlaps': overlaps,
185             'flipped': False,
186             'seg_areas': seg_areas}
187 
188   def _get_comp_id(self):
189     comp_id = (self._comp_id + '_' + self._salt if self.config['use_salt']
190                else self._comp_id)
191     return comp_id
192 
193   def _get_voc_results_file_template(self):
194     # VOCdevkit/results/VOC2007/Main/<comp_id>_det_test_aeroplane.txt
195     filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'
196     path = os.path.join(
197       self._devkit_path,
198       'results',
199       'VOC' + self._year,
200       'Main',
201       filename)
202     return path
203 
204   def _write_voc_results_file(self, all_boxes):
205     for cls_ind, cls in enumerate(self.classes):
206       if cls == '__background__':
207         continue
208       print('Writing {} VOC results file'.format(cls))
209       filename = self._get_voc_results_file_template().format(cls)
210       with open(filename, 'wt') as f:
211         for im_ind, index in enumerate(self.image_index):
212           dets = all_boxes[cls_ind][im_ind]
213           if dets == []:
214             continue
215           # the VOCdevkit expects 1-based indices
216           for k in range(dets.shape[0]):
217             f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
218                     format(index, dets[k, -1],
219                            dets[k, 0] + 1, dets[k, 1] + 1,
220                            dets[k, 2] + 1, dets[k, 3] + 1))
221 
222   def _do_python_eval(self, output_dir='output'):
223     annopath = os.path.join(
224       self._devkit_path,
225       'VOC' + self._year,
226       'Annotations',
227       '{:s}.xml')
228     imagesetfile = os.path.join(
229       self._devkit_path,
230       'VOC' + self._year,
231       'ImageSets',
232       'Main',
233       self._image_set + '.txt')
234     cachedir = os.path.join(self._devkit_path, 'annotations_cache')
235     aps = []
236     # The PASCAL VOC metric changed in 2010
237     use_07_metric = True if int(self._year) < 2010 else False
238     print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
239     if not os.path.isdir(output_dir):
240       os.mkdir(output_dir)
241     for i, cls in enumerate(self._classes):
242       if cls == '__background__':
243         continue
244       filename = self._get_voc_results_file_template().format(cls)
245       rec, prec, ap = voc_eval(
246         filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
247         use_07_metric=use_07_metric, use_diff=self.config['use_diff'])
248       aps += [ap]
249       print(('AP for {} = {:.4f}'.format(cls, ap)))
250       with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
251         pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
252     print(('Mean AP = {:.4f}'.format(np.mean(aps))))
253     print('~~~~~~~~')
254     print('Results:')
255     for ap in aps:
256       print(('{:.3f}'.format(ap)))
257     print(('{:.3f}'.format(np.mean(aps))))
258     print('~~~~~~~~')
259     print('')
260     print('--------------------------------------------------------------')
261     print('Results computed with the **unofficial** Python eval code.')
262     print('Results should be very close to the official MATLAB eval code.')
263     print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
264     print('-- Thanks, The Management')
265     print('--------------------------------------------------------------')
266 
267   def _do_matlab_eval(self, output_dir='output'):
268     print('-----------------------------------------------------')
269     print('Computing results with the official MATLAB eval code.')
270     print('-----------------------------------------------------')
271     path = os.path.join(cfg.ROOT_DIR, 'lib', 'datasets',
272                         'VOCdevkit-matlab-wrapper')
273     cmd = 'cd {} && '.format(path)
274     cmd += '{:s} -nodisplay -nodesktop '.format(cfg.MATLAB)
275     cmd += '-r "dbstop if error; '
276     cmd += 'voc_eval(\'{:s}\',\'{:s}\',\'{:s}\',\'{:s}\'); quit;"' \
277       .format(self._devkit_path, self._get_comp_id(),
278               self._image_set, output_dir)
279     print(('Running:\n{}'.format(cmd)))
280     status = subprocess.call(cmd, shell=True)
281 
282   def evaluate_detections(self, all_boxes, output_dir):
283     self._write_voc_results_file(all_boxes)
284     self._do_python_eval(output_dir)
285     if self.config['matlab_eval']:
286       self._do_matlab_eval(output_dir)
287     if self.config['cleanup']:
288       for cls in self._classes:
289         if cls == '__background__':
290           continue
291         filename = self._get_voc_results_file_template().format(cls)
292         os.remove(filename)
293 
294   def competition_mode(self, on):
295     if on:
296       self.config['use_salt'] = False
297       self.config['cleanup'] = False
298     else:
299       self.config['use_salt'] = True
300       self.config['cleanup'] = True
301 
302 
if __name__ == '__main__':
  # Manual smoke test: build the VOC2007 trainval imdb, load its roidb, then
  # open an interactive IPython shell for inspection.
  from datasets.pascal_voc import pascal_voc

  d = pascal_voc('trainval', '2007')
  res = d.roidb
  from IPython import embed
  embed()

我們先看涉及mAP計算的方法,其他方法暫時略過。

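這裡evaluate_detections方法先把每個類別的檢測結果寫成txt文件(_write_voc_results_file),再調用_do_python_eval方法;後者對每個類別調用voc_eval函數計算該類的AP(245-247行),最後用np.mean(aps)得到mAP(252行)。voc_eval定義在FRCN/lib/datasets/voc_eval.py: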

  1 # --------------------------------------------------------
  2 # Fast/er R-CNN
  3 # Licensed under The MIT License [see LICENSE for details]
  4 # Written by Bharath Hariharan
  5 # --------------------------------------------------------
  6 from __future__ import absolute_import
  7 from __future__ import division
  8 from __future__ import print_function
  9 
 10 import xml.etree.ElementTree as ET
 11 import os
 12 import pickle
 13 import numpy as np
 14 
 15 def parse_rec(filename):
 16   """ Parse a PASCAL VOC xml file """
 17   tree = ET.parse(filename)
 18   objects = []
 19   for obj in tree.findall('object'):
 20     obj_struct = {}
 21     obj_struct['name'] = obj.find('name').text
 22     obj_struct['pose'] = obj.find('pose').text
 23     obj_struct['truncated'] = int(obj.find('truncated').text)
 24     obj_struct['difficult'] = int(obj.find('difficult').text)
 25     bbox = obj.find('bndbox')
 26     obj_struct['bbox'] = [int(float(bbox.find('xmin').text)),
 27                           int(float(bbox.find('ymin').text)),
 28                           int(float(bbox.find('xmax').text)),
 29                           int(float(bbox.find('ymax').text))]
 30     objects.append(obj_struct)
 31 
 32   return objects
 33 
 34 
 35 def voc_ap(rec, prec, use_07_metric=False):
 36   """ ap = voc_ap(rec, prec, [use_07_metric])
 37   Compute VOC AP given precision and recall.
 38   If use_07_metric is true, uses the
 39   VOC 07 11 point method (default:False).
 40   """
 41   if use_07_metric:
 42     # 11 point metric
 43     ap = 0.
 44     for t in np.arange(0., 1.1, 0.1):
 45       if np.sum(rec >= t) == 0:
 46         p = 0
 47       else:
 48         p = np.max(prec[rec >= t])
 49       ap = ap + p / 11.
 50   else:
 51     # correct AP calculation
 52     # first append sentinel values at the end
 53     mrec = np.concatenate(([0.], rec, [1.]))
 54     mpre = np.concatenate(([0.], prec, [0.]))
 55 
 56     # compute the precision envelope
 57     for i in range(mpre.size - 1, 0, -1):
 58       mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
 59 
 60     # to calculate area under PR curve, look for points
 61     # where X axis (recall) changes value
 62     i = np.where(mrec[1:] != mrec[:-1])[0]
 63 
 64     # and sum (\Delta recall) * prec
 65     ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
 66   return ap
 67 
 68 
 69 def voc_eval(detpath,
 70              annopath,
 71              imagesetfile,
 72              classname,
 73              cachedir,
 74              ovthresh=0.5,
 75              use_07_metric=False,
 76              use_diff=False):
 77   """rec, prec, ap = voc_eval(detpath,
 78                               annopath,
 79                               imagesetfile,
 80                               classname,
 81                               [ovthresh],
 82                               [use_07_metric])
 83 
 84   Top level function that does the PASCAL VOC evaluation.
 85 
 86   detpath: Path to detections
 87       detpath.format(classname) should produce the detection results file.
 88   annopath: Path to annotations
 89       annopath.format(imagename) should be the xml annotations file.
 90   imagesetfile: Text file containing the list of images, one image per line.
 91   classname: Category name (duh)
 92   cachedir: Directory for caching the annotations
 93   [ovthresh]: Overlap threshold (default = 0.5)
 94   [use_07_metric]: Whether to use VOC07's 11 point AP computation
 95       (default False)
 96   """
 97   # assumes detections are in detpath.format(classname)
 98   # assumes annotations are in annopath.format(imagename)
 99   # assumes imagesetfile is a text file with each line an image name
100   # cachedir caches the annotations in a pickle file
101 
102   # first load gt
103   if not os.path.isdir(cachedir):
104     os.mkdir(cachedir)
105   cachefile = os.path.join(cachedir, '%s_annots.pkl' % imagesetfile)
106   # read list of images
107   with open(imagesetfile, 'r') as f:
108     lines = f.readlines()
109   imagenames = [x.strip() for x in lines]    #test.txt中的所有標號
110 
111   # load annotations
112   if not os.path.isfile(cachefile):    
113     recs = {}
114     for i, imagename in enumerate(imagenames):
115       recs[imagename] = parse_rec(annopath.format(imagename))
116       if i % 100 == 0:
117         print('Reading annotation for {:d}/{:d}'.format(
118           i + 1, len(imagenames)))
119     # save
120     print('Saving cached annotations to {:s}'.format(cachefile))
121     with open(cachefile, 'wb') as f:
122       pickle.dump(recs, f)
123   else:
124     # load
125     with open(cachefile, 'rb') as f:
126       try:
127         recs = pickle.load(f)
128       except:
129         recs = pickle.load(f, encoding='bytes')
130 
131   # extract gt objects for this class
132   class_recs = {}
133   npos = 0
134   for imagename in imagenames:
135     R = [obj for obj in recs[imagename] if obj['name'] == classname]
136     bbox = np.array([x['bbox'] for x in R])
137     if use_diff:
138       difficult = np.array([False for x in R]).astype(np.bool)      
139     else:
140       difficult = np.array([x['difficult'] for x in R]).astype(np.bool)
141     det = [False] * len(R)
142     npos = npos + sum(~difficult)
143     class_recs[imagename] = {'bbox': bbox,
144                              'difficult': difficult,
145                              'det': det}
146 
147   # read dets
148   detfile = detpath.format(classname)
149   with open(detfile, 'r') as f:
150     lines = f.readlines()
151 
152   splitlines = [x.strip().split(' ') for x in lines]
153   image_ids = [x[0] for x in splitlines]
154   confidence = np.array([float(x[1]) for x in splitlines])
155   BB = np.array([[float(z) for z in x[2:]] for x in splitlines])
156 
157   nd = len(image_ids)
158   tp = np.zeros(nd)
159   fp = np.zeros(nd)
160 
161   if BB.shape[0] > 0:
162     # sort by confidence
163     sorted_ind = np.argsort(-confidence)
164     sorted_scores = np.sort(-confidence)
165     BB = BB[sorted_ind, :]
166     image_ids = [image_ids[x] for x in sorted_ind]
167 
168     # go down dets and mark TPs and FPs
169     for d in range(nd):
170       R = class_recs[image_ids[d]]
171       bb = BB[d, :].astype(float)
172       ovmax = -np.inf
173       BBGT = R['bbox'].astype(float)
174 
175       if BBGT.size > 0:
176         # compute overlaps
177         # intersection
178         ixmin = np.maximum(BBGT[:, 0], bb[0])
179         iymin = np.maximum(BBGT[:, 1], bb[1])
180         ixmax = np.minimum(BBGT[:, 2], bb[2])
181         iymax = np.minimum(BBGT[:, 3], bb[3])
182         iw = np.maximum(ixmax - ixmin + 1., 0.)
183         ih = np.maximum(iymax - iymin + 1., 0.)
184         inters = iw * ih
185 
186         # union
187         uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
188                (BBGT[:, 2] - BBGT[:, 0] + 1.) *
189                (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)
190 
191         overlaps = inters / uni
192         ovmax = np.max(overlaps)
193         jmax = np.argmax(overlaps)
194 
195       if ovmax > ovthresh:
196         if not R['difficult'][jmax]:
197           if not R['det'][jmax]:
198             tp[d] = 1.
199             R['det'][jmax] = 1
200           else:
201             fp[d] = 1.
202       else:
203         fp[d] = 1.
204 
205   # compute precision recall
206   fp = np.cumsum(fp)
207   tp = np.cumsum(tp)
208   rec = tp / float(npos)
209   # avoid divide by zero in case the first detection matches a difficult
210   # ground truth
211   prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
212   ap = voc_ap(rec, prec, use_07_metric)
213 
214   return rec, prec, ap

 

voc_eval(filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5, use_07_metric=use_07_metric, use_diff=self.config['use_diff'])

def voc_eval(detpath, annopath, imagesetfile, classname, cachedir, ovthresh=0.5, use_07_metric=False, use_diff=False):

filename(即detpath): 當前類別檢測結果txt文件的路徑(由_write_voc_results_file寫出,每行格式為 image_id score x1 y1 x2 y2)

annopath: Annotations xml文件的路徑模板,annopath.format(imagename)即為該圖片的標註文件

imagesetfile: 圖片集列表的txt文件,每行一個image id

classname: 當前的class

cachedir: 緩存解析後annotations的pkl文件所在目錄(不存在時會自動創建)

ovthresh=0.5: IoU的threshold,默認為0.5

use_07_metric=False: 是否使用VOC2007的11點插值AP計算規則(_do_python_eval中年份早於2010時為True)

use_diff=False: 是否把標記為difficult的GT當作普通GT參與評估。為False時,difficult GT不計入npos,與其匹配的檢測既不算TP也不算FP

 

經過一番數據處理,得到了:

BB: 當前class的所有檢測框(predicted bbox)的坐標,每行一個框

image_ids: 每個檢測框所屬image的id(與BB逐行對應,而不是imageset中的圖片列表)

class_recs: dict,key為image id,value包含該image中當前class的GT bbox、difficult標記,以及記錄每個GT是否已被匹配的det列表

 

 1   if BB.shape[0] > 0:
 2     # sort by confidence
 3     #'''
 4     sorted_ind = np.argsort(-confidence)
 5     sorted_scores = np.sort(-confidence)
 6     BB = BB[sorted_ind, :]    # 現在的BB是按照conf降序排列的所有predicted bbox
 7     image_ids = [image_ids[x] for x in sorted_ind]    # image_id 是BB每組bbox所屬於的image的序號
 8     
 9     #'''
10      
11     # go down dets and mark TPs and FPs
12     for d in range(nd):              #對所有proposal bbox 遍歷
13       R = class_recs[image_ids[d]]   # 找到當前bbox對應的image
14       bb = BB[d, :].astype(float)    # bb 為當前proposal bbox的坐標
15       ovmax = -np.inf                # 設置np極小值
16       BBGT = R['bbox'].astype(float) 
17 
18       if BBGT.size > 0:
19         # compute overlaps
20         # intersection
21         ixmin = np.maximum(BBGT[:, 0], bb[0])
22         iymin = np.maximum(BBGT[:, 1], bb[1])
23         ixmax = np.minimum(BBGT[:, 2], bb[2])
24         iymax = np.minimum(BBGT[:, 3], bb[3])
25         iw = np.maximum(ixmax - ixmin + 1., 0.)
26         ih = np.maximum(iymax - iymin + 1., 0.)
27         inters = iw * ih
28 
29         # union
30         uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
31                (BBGT[:, 2] - BBGT[:, 0] + 1.) *
32                (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)
33 
34         overlaps = inters / uni
35         ovmax = np.max(overlaps)
36         jmax = np.argmax(overlaps)
37         print(overlaps)
38 
39       if ovmax > ovthresh:
40         if not R['difficult'][jmax]:    
41           if not R['det'][jmax]:        #是否已經被檢測過
42             tp[d] = 1.
43             R['det'][jmax] = 1
44           else:
45             fp[d] = 1.
46       else:
47         fp[d] = 1.

 

疑惑:

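這裡的Recall計算(voc_eval.py 208行)使用了:

rec = tp / float(npos)。npos實際上是所有(非difficult的)GT bbox的數量,它一定等於tp+fn嗎?如果把fn理解為"包含目標但沒有被檢測出bbox的image數量",那麼只有當它恰好等於npos-tp(未被檢測出的bbox數量)時才成立。

解答:在目標檢測中,FN是按GT框計數的,而不是按image計數。每個沒有被任何檢測匹配上的非difficult GT框都算一個FN。由於每個GT框最多只會被匹配一次(R['det']標記),tp就等於被匹配上的GT框數,因此fn = npos - tp,tp + fn = npos恆成立。rec = tp / npos就是標準的召回率。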

 

ref: 1. https://datascience.stackexchange.com/questions/25119/how-to-calculate-map-for-detection-task-for-the-pascal-voc-challenge

2. http://mp.weixin.qq.com/s/FaNC9RppIhPf6T_qAz3Slg

3. https://ils.unc.edu/courses/2013_spring/inls509_001/lectures/10-EvaluationMetrics.pdf

4. https://stats.stackexchange.com/questions/260430/average-precision-in-object-detection/263758#263758


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM