本blog為github上CharlesShang/TFFRCNN版源碼解析系列代碼筆記
---------------個人學習筆記---------------
----------------本文作者疆--------------
------點擊此處鏈接至博客園原文------
1.主函數調用函數執行順序:
parse_args()解析運行參數(如--gpu 0 ...)--->get_network(args.demo_net)加載網絡(factory.py中)得到net
--->tf內部機制創建sess和恢復網絡模型等--->glob.glob('圖像地址')返回im_names地址列表(glob.py中)--->逐張圖像
循環調用demo(sess,net,im_name)
2.parse_args()函數返回args
parser = argparse.ArgumentParser(description='Faster R-CNN demo') # 新建一個解析對象 parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',default=0, type=int) # 含默認值 ... args = parser.parse_args() # 類內同名函數

# -*- coding:utf-8 -*-
# Author: WUJiang
# argparse test: `dest` renames the attribute on the parsed Namespace.
import argparse

parser = argparse.ArgumentParser(description="test")
# The command-line flags are --mode / --day, but `dest` decides the attribute
# name, so the values are read back as args.work_mode / args.date.
# type=int keeps command-line values the same type as the int defaults
# (without it, "--day 5" would produce the string '5').
parser.add_argument('--mode', dest='work_mode', default=0, type=int)
parser.add_argument('--day', dest='date', default=4, type=int)

if __name__ == "__main__":
    # Parse only when run as a script, so importing this module does not
    # consume (and possibly reject) another program's sys.argv.
    args = parser.parse_args()
    # With no arguments this prints a Namespace holding date=4 and work_mode=0
    # (the dest names). args.day and args.mode do NOT exist and would raise
    # AttributeError.
    print(args)
3.demo()函數的執行邏輯
demo(sess, net, image_name)--->讀取圖像後調用im_detect(sess, net, im)(test.py中),返回scores和boxes(應注意其維度:scores為R×K,boxes為R×4K,R為boxes個數,K為含背景在內的類別數)。其中計時僅統計了im_detect函數的耗時,未統計nms等後續處理的耗時
--->設置CONF_THRESH得分閾值和NMS_THRESH閾值--->針對各個類別,構造該類的dets(R*5,R表示R個box,5=4坐標+1得分)
--->該類內執行nms(nms_wrapper.py中)(IoU閾值為NMS_THRESH=0.3)--->vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)對該類檢測結果進行繪制,需用到得分閾值(demo.py中)
def im_detect(sess, net, im, boxes=None):
    """Detect object classes in an image given object proposals.

    Arguments:
        sess: TensorFlow session used to run the network
            (NOTE(review): inferred from the demo flow; confirm in test.py).
        net: network to run. The original text below was inherited from the
            Caffe version ("net (caffe.Net): Fast R-CNN network to use"); in
            this TF port it is presumably the object returned by
            get_network() -- confirm.
        im (ndarray): color image to test (in BGR order)
        boxes (ndarray): R x 4 array of object proposals

    Returns:
        scores (ndarray): R x K array of object class scores (K includes
            background as object category 0)
        boxes (ndarray): R x (4*K) array of predicted bounding boxes
    """
    # Function body omitted in this excerpt; the full implementation is in test.py.
# 針對每個boxes,得到其屬於各類的得分及按各類得到的回歸boxes,若得分閾值設置較低,會看到圖像某個目標被檢測出多類超過閾值得分的box盒
CONF_THRESH = 0.8  # minimum score for a detection to be drawn
NMS_THRESH = 0.3   # IoU threshold passed to nms()
# Column 0 of `scores` (and the first 4 columns of `boxes`) belong to the
# background class, so CLASSES[0] is skipped; start=1 keeps cls_ind aligned
# with those columns instead of incrementing it by hand.
for cls_ind, cls in enumerate(CLASSES[1:], start=1):
    cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]  # (R, 4): boxes regressed for this class
    cls_scores = scores[:, cls_ind]  # (R,): this class's score for every box
    # Append the score as a 5th column -> (R, 5) float32 array: 4 coords + score.
    dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32)
    keep = nms(dets, NMS_THRESH)  # indices of the boxes that survive per-class NMS
    dets = dets[keep, :]
    vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)

# -*- coding:utf-8 -*-
# Author: WUJiang
# What np.newaxis does: it inserts a new axis of length 1 where it appears.
import numpy as np

a = np.array([1, 2, 3, 4, 5])  # shape (5,)

# Leading new axis -> shape (1, 5), a row vector;
# trailing new axis -> shape (5, 1), a column vector.
b, c = a[np.newaxis, :], a[:, np.newaxis]

print(b)
# Output:
# [[1 2 3 4 5]]

print(c)
# Output:
# [[1]
#  [2]
#  [3]
#  [4]
#  [5]]

# -*- coding:utf-8 -*-
# Author: WUJiang
# Array concatenation: np.hstack joins along columns, np.vstack along rows.
import numpy as np

a = np.array([[2, 2],
              [0, 3]])
b = np.array([[4, 1],
              [3, 1]])

pair = (a, b)

# Side by side (columns appended):
# [[2 2 4 1]
#  [0 3 3 1]]
print(np.hstack(pair))

# One on top of the other (rows appended):
# [[2 2]
#  [0 3]
#  [4 1]
#  [3 1]]
print(np.vstack(pair))
4.demo()中圖像目標檢測時間的獲取(Timer是在lib.utils.timer中定義的類)
from lib.utils.timer import Timer

# A new Timer is created for this call, so its measurement covers only this
# image's im_detect() call; it is not an average over several images.
timer = Timer()
timer.tic()
scores, boxes = im_detect(sess, net, im)
# Only im_detect() falls between tic() and toc(); later steps such as NMS are
# not timed. toc()'s return value is not used in this excerpt.
timer.toc()
實際上每次重新實例化timer,因此計算的時間即是t2-t1(簡單地獲取當時時間戳time.time()),而不是多張圖像的平均檢測時間

import time

# Measure intervals with a monotonic, high-resolution clock. time.time() is
# wall-clock time and can jump when the system clock is adjusted, producing
# negative or inflated intervals. Fall back to time.time only on interpreters
# that lack perf_counter (Python < 3.3).
_clock = getattr(time, "perf_counter", time.time)


class Timer(object):
    """A simple tic/toc stopwatch that also keeps a running average.

    Attributes:
        total_time: sum of all measured intervals, in seconds.
        calls: number of completed tic()/toc() pairs.
        start_time: clock reading taken by the latest tic(); with
            perf_counter this is not an epoch timestamp.
        diff: length of the most recent interval, in seconds.
        average_time: total_time / calls.
    """

    def __init__(self):
        self.total_time = 0.
        self.calls = 0
        self.start_time = 0.
        self.diff = 0.
        self.average_time = 0.

    def tic(self):
        """Start (or restart) timing an interval."""
        self.start_time = _clock()

    def toc(self, average=True):
        """Stop timing, record the interval and return a duration in seconds.

        Args:
            average: if True, return the running average over all calls so
                far; otherwise return the interval that just ended.
        """
        self.diff = _clock() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls
        return self.average_time if average else self.diff
5.vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)函數(暫時忽略matplotlib.pyplot即plt模塊相關繪制功能)的執行邏輯
得到dets中得分不低於CONF_THRESH(代碼中為>=)的索引inds,若inds為空則直接返回;否則對於該類遍歷各個得分不低於CONF_THRESH的bbox進行繪制

# Indices of the detections whose score (last column of dets) is >= thresh.
inds = np.where(dets[:, -1] >= thresh)[0]
if len(inds) == 0:
    return  # no detection of this class reaches the threshold: draw nothing
for i in inds:
    bbox = dets[i, :4]  # the 4 box coordinates
    score = dets[i, -1]  # the detection score
    ...  # drawing code omitted in this excerpt

# -*- coding:utf-8 -*-
# Author: WUJiang
# np.where test: with only a condition, it returns index arrays, one per axis.
import numpy as np

a = np.array([
    [1, 2, 4, 5, 0.8],
    [2, 5, 7, 9, 0.9],
    [3, 6, 5, 20, 0.95],
])

# 2-D condition -> a tuple of two index arrays (rows, columns) locating the
# 4 elements greater than 5:
# (array([1, 1, 2, 2]), array([2, 3, 1, 3]))  (some platforms also show dtype=int64)
b = np.where(a > 5)
b0 = b[0]  # row indices [1 1 2 2]; b itself is a tuple
print(type(b0))  # <class 'numpy.ndarray'>

scores = a[:, 4]  # last column, like the score column of dets
print(scores)  # [0.8  0.9  0.95]

# 1-D condition -> still a tuple, holding a single index array: (array([1, 2]),)
c = np.where(scores > 0.8)
print(c[0])  # [1 2]