TensorFlow Faster RCNN Source Code Analysis (TFFRCNN) (01): demo.py (covering the argparse module and numpy's newaxis, hstack, vstack and np.where)


This blog post is part of a series of code notes analyzing the source of CharlesShang/TFFRCNN on GitHub

---------------Personal study notes---------------

----------------Author: 疆--------------

------Originally published on cnblogs (博客園)------

 

1. Order of function calls in the main function:

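parse_args() parses the command-line arguments (e.g. --gpu 0 ...)
---> get_network(args.demo_net) loads the network (defined in factory.py) and returns net
---> TensorFlow's own machinery creates sess, restores the network model, etc.
---> glob.glob('image path') returns im_names, the list of image paths (glob.py, the standard-library glob module)
---> loop over the images, calling demo(sess, net, im_name) for each one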
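To see the call order end to end without TensorFlow, here is a minimal runnable sketch. get_network and the session are replaced by None placeholders, and demo() is a hypothetical stand-in that only returns the file name; just the argparse, glob and per-image loop structure follows the description above. The '*.jpg' pattern and the temporary directory are assumptions made for the example.

```python
import argparse
import glob
import os
import tempfile


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description='Faster R-CNN demo')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
                        default=0, type=int)
    return parser.parse_args(argv)


def demo(sess, net, im_name):
    """Hypothetical stand-in for demo.py's demo(): just reports the file."""
    return os.path.basename(im_name)


def main(argv, image_dir):
    args = parse_args(argv)
    net, sess = None, None   # placeholders for get_network(...) and the TF session
    im_names = sorted(glob.glob(os.path.join(image_dir, '*.jpg')))
    return args, [demo(sess, net, im_name) for im_name in im_names]


with tempfile.TemporaryDirectory() as d:
    # two fake "images" plus a file the glob pattern should ignore
    for name in ('a.jpg', 'b.jpg', 'notes.txt'):
        open(os.path.join(d, name), 'w').close()
    args, processed = main(['--gpu', '1'], d)
    print(args.gpu_id, processed)   # 1 ['a.jpg', 'b.jpg']
```

Sorting the glob result is a choice made here for reproducible order; glob.glob itself returns paths in arbitrary order.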

2. parse_args() returns args

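import argparse

parser = argparse.ArgumentParser(description='Faster R-CNN demo')  # create a parser object
parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=0, type=int)  # with a default value
...
args = parser.parse_args()  # ArgumentParser's own parse_args() method, not the same-named parse_args() in demo.py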
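# -*- coding:utf-8 -*-
# Author: WUJiang
# Testing the argparse module
import argparse

parser = argparse.ArgumentParser(description="test")
parser.add_argument('--mode', dest='work_mode', default=0, type=int)  # dest renames the attribute; default value
parser.add_argument('--day', dest='date', default=4, type=int)
args = parser.parse_args()
print(args)  # with no arguments: a Namespace holding work_mode=0 and date=4
# Because dest is given, the values live in args.work_mode (0) and args.date (4);
# args.mode and args.day do not exist (accessing them raises AttributeError)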
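The effect of dest, default and type is easier to check by passing an explicit argument list to parse_args() instead of reading the real command line. The sketch below reuses the two options from the test above; the values '1' and '7' are arbitrary.

```python
import argparse

parser = argparse.ArgumentParser(description="test")
parser.add_argument('--mode', dest='work_mode', default=0, type=int)
parser.add_argument('--day', dest='date', default=4, type=int)

# An explicit list replaces sys.argv[1:], so the behaviour is reproducible
defaults = parser.parse_args([])
custom = parser.parse_args(['--mode', '1', '--day', '7'])

print(defaults.work_mode, defaults.date)                # 0 4  (from default=)
print(custom.work_mode, custom.date)                    # 1 7  (strings converted by type=int)
print(hasattr(custom, 'mode'), hasattr(custom, 'day'))  # False False
```

Without type=int, command-line values would arrive as strings ('1'), while the defaults would stay ints, which is why the demo's --gpu option sets type=int.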

3. Execution logic of demo()

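demo(sess, net, image_name)
---> read the image and call im_detect(sess, net, im) (in test.py), which returns scores and boxes; note their shapes: scores is R x K and boxes is R x 4K, where R is the number of boxes and K the number of classes including background. The reported time covers only im_detect itself, not NMS or other post-processing
---> set the score threshold CONF_THRESH and the NMS threshold NMS_THRESH
---> for each class, build that class's dets array (R x 5: R boxes, 5 = 4 coordinates + 1 score)
---> run NMS within the class (nms_wrapper.py) with an IoU threshold of 0.3
---> vis_detections(im, cls, dets, ax, thresh=CONF_THRESH) (in demo.py) draws the class's detections, using the score threshold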

# test.py: signature and docstring of im_detect (the docstring is inherited from
# the Caffe version; in TFFRCNN, net is the TensorFlow network)
def im_detect(sess, net, im, boxes=None):
    """Detect object classes in an image given object proposals.
    Arguments:
        net (caffe.Net): Fast R-CNN network to use
        im (ndarray): color image to test (in BGR order)
        boxes (ndarray): R x 4 array of object proposals
    Returns:
        scores (ndarray): R x K array of object class scores (K includes
            background as object category 0)
        boxes (ndarray): R x (4*K) array of predicted bounding boxes
    """
    ...

# Excerpt from demo() in demo.py, after scores, boxes = im_detect(sess, net, im).
# For every box, im_detect returns a score for each class plus a class-specific
# regressed box, so with a low score threshold the same object may be drawn
# several times, once for each class whose score exceeds the threshold.
CONF_THRESH = 0.8
NMS_THRESH = 0.3
for cls_ind, cls in enumerate(CLASSES[1:]):
    cls_ind += 1  # because we skipped background
    cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]    # (R, 4)
    cls_scores = scores[:, cls_ind]                         # (R,), 1-D
    dets = np.hstack((cls_boxes,
                      cls_scores[:, np.newaxis])).astype(np.float32)  # (R, 5)
    keep = nms(dets, NMS_THRESH)
    dets = dets[keep, :]
    vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)
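The per-class loop can be tried without TensorFlow. The sketch below fakes im_detect's outputs (scores of shape (R, K), boxes of shape (R, 4K)) and uses py_nms, a plain-NumPy greedy NMS written here as a stand-in. TFFRCNN's nms_wrapper.nms may dispatch to a compiled CPU or GPU kernel instead, so treat this as an illustration of the idea, not its implementation; class_dets is likewise a hypothetical helper.

```python
import numpy as np


def py_nms(dets, thresh):
    """Greedy NMS on an (N, 5) array of [x1, y1, x2, y2, score] rows.

    Plain-NumPy stand-in for nms_wrapper.nms; uses the "+1" pixel
    convention common in Faster R-CNN code for widths and heights.
    """
    x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]        # indices by descending score
    keep = []
    while order.size > 0:
        i = order[0]                      # highest-scoring remaining box
        keep.append(int(i))
        rest = order[1:]
        # intersection of box i with every remaining box
        w = np.maximum(0.0, np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]) + 1)
        h = np.maximum(0.0, np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]) + 1)
        inter = w * h
        iou = inter / (areas[i] + areas[rest] - inter)
        order = rest[iou <= thresh]       # drop boxes that overlap box i too much
    return keep


def class_dets(scores, boxes, cls_ind, nms_thresh=0.3):
    """Build and NMS-filter the (R, 5) dets array for one class, as in demo()."""
    cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]              # (R, 4)
    cls_scores = scores[:, cls_ind]                                   # (R,)
    dets = np.hstack((cls_boxes,
                      cls_scores[:, np.newaxis])).astype(np.float32)  # (R, 5)
    keep = py_nms(dets, nms_thresh)
    return dets[keep, :]


# Fake im_detect output: R = 3 boxes, K = 2 classes (0 = background, 1 = object)
scores = np.array([[0.1, 0.9],
                   [0.2, 0.8],
                   [0.3, 0.7]])
boxes = np.array([[0, 0, 5, 5, 0, 0, 10, 10],
                  [0, 0, 5, 5, 1, 1, 10, 10],
                  [0, 0, 5, 5, 20, 20, 30, 30]], dtype=np.float64)

dets1 = class_dets(scores, boxes, cls_ind=1)
print(dets1)   # box 1 overlaps box 0 (IoU = 100/121, about 0.83) and is suppressed
```

Only two rows survive for class 1: the 0.9 box and the non-overlapping 0.7 box.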
# -*- coding:utf-8 -*-
# Author: WUJiang
# What np.newaxis does
import numpy as np

a = np.array([1, 2, 3, 4, 5])
b = a[np.newaxis, :]  # shape (1, 5): a row vector
c = a[:, np.newaxis]  # shape (5, 1): a column vector
# [[1 2 3 4 5]]
print(b)
"""
[[1]
 [2]
 [3]
 [4]
 [5]]
 """
print(c)
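This is why demo() needs np.newaxis: cls_scores = scores[:, cls_ind] is 1-D with shape (R,), and it can only be glued onto the (R, 4) box array after becoming an (R, 1) column. The short check below (plain NumPy, no project code) prints the shapes and shows two equivalent spellings.

```python
import numpy as np

a = np.array([1, 2, 3, 4, 5])

row = a[np.newaxis, :]        # shape (1, 5)
col = a[:, np.newaxis]        # shape (5, 1)

# np.newaxis is simply an alias for None, so a[:, None] means the same thing;
# reshape(-1, 1) produces the same column without indexing tricks
col_none = a[:, None]
col_reshape = a.reshape(-1, 1)

print(np.newaxis is None)                          # True
print(a.shape, row.shape, col.shape)               # (5,) (1, 5) (5, 1)
print(np.array_equal(col, col_none),
      np.array_equal(col, col_reshape))            # True True
```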
# -*- coding:utf-8 -*-
# Author: WUJiang
# Array concatenation

import numpy as np

a = np.array([
    [2, 2],
    [0, 3]
])
b = np.array([
    [4, 1],
    [3, 1]
])
"""
[[2 2 4 1]
 [0 3 3 1]]
"""
print(np.hstack((a, b)))
"""
[[2 2]
 [0 3]
 [4 1]
 [3 1]]
"""
print(np.vstack((a, b)))
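The same idea applied to how demo() builds dets, with made-up boxes and scores: for 2-D inputs np.hstack is concatenation along axis 1 (np.vstack along axis 0), and passing the 1-D score vector directly fails because the arrays have different numbers of dimensions. np.column_stack, which treats 1-D inputs as columns, is shown as an alternative that avoids the explicit np.newaxis.

```python
import numpy as np

cls_boxes = np.array([[0., 0., 10., 10.],
                      [20., 20., 30., 30.]])   # (R, 4) with R = 2 (made-up values)
cls_scores = np.array([0.9, 0.7])              # (R,), like scores[:, cls_ind]

# For 2-D arrays, hstack == concatenate along axis 1
dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis]))
same = np.concatenate((cls_boxes, cls_scores[:, np.newaxis]), axis=1)
print(dets.shape, np.array_equal(dets, same))  # (2, 5) True

# column_stack turns 1-D inputs into columns, so no newaxis is needed
alt = np.column_stack((cls_boxes, cls_scores))
print(np.array_equal(dets, alt))               # True

# Without newaxis the score vector stays 1-D and hstack raises ValueError
try:
    np.hstack((cls_boxes, cls_scores))
    mismatch_error = False
except ValueError:
    mismatch_error = True
print(mismatch_error)                          # True
```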

4. Measuring the detection time in demo() (Timer is a class defined in lib.utils.timer)

from lib.utils.timer import Timer
timer = Timer()
timer.tic()
scores, boxes = im_detect(sess, net, im)
timer.toc()

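Note that demo() creates a new Timer on every call, so tic()/toc() measure a single interval t2 - t1 (plain time.time() timestamps), not the average detection time over several images: with calls == 1, the average_time that toc() returns by default is simply equal to diff.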

import time

class Timer(object):
    """A simple timer."""
    def __init__(self):
        self.total_time = 0.
        self.calls = 0
        self.start_time = 0.
        self.diff = 0.
        self.average_time = 0.

    def tic(self):
        # using time.time instead of time.clock because time.clock
        # does not normalize for multithreading
        self.start_time = time.time()

    def toc(self, average=True):
        self.diff = time.time() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls
        if average:
            return self.average_time
        else:
            return self.diff
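A small variation, sketched under assumptions: to get an average over several images, create one Timer outside the image loop and reuse it, so calls and total_time accumulate. The class below copies the logic of the Timer above but uses time.perf_counter (a monotonic, high-resolution clock) instead of time.time; that swap is mine, not TFFRCNN's. fake_detect is a hypothetical stand-in for im_detect.

```python
import time


class Timer(object):
    """Same logic as the Timer above, but timed with time.perf_counter."""

    def __init__(self):
        self.total_time = 0.
        self.calls = 0
        self.start_time = 0.
        self.diff = 0.
        self.average_time = 0.

    def tic(self):
        self.start_time = time.perf_counter()

    def toc(self, average=True):
        self.diff = time.perf_counter() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls
        return self.average_time if average else self.diff


def fake_detect(n):
    """Hypothetical stand-in for im_detect: just burns a little CPU."""
    return sum(i * i for i in range(n))


# ONE timer created outside the loop, so the statistics accumulate
timer = Timer()
for n in (10000, 20000, 30000):      # pretend these are three images
    timer.tic()
    fake_detect(n)
    timer.toc()

print('images: %d, total: %.6fs, average: %.6fs'
      % (timer.calls, timer.total_time, timer.average_time))
```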

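5. Execution logic of vis_detections(im, cls, dets, ax, thresh=CONF_THRESH) (the matplotlib.pyplot, i.e. plt, drawing code is ignored for now)

Find inds, the indices of the rows in dets whose score is at least CONF_THRESH (the comparison is >=), then iterate over those bboxes of the class and draw each one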

    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]
...
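The filtering part of vis_detections can be run on its own. select_detections below is a hypothetical helper (not in TFFRCNN) that keeps the np.where(dets[:, -1] >= thresh) selection and the loop over inds, but collects (bbox, score) pairs instead of drawing with matplotlib; the dets values are made up.

```python
import numpy as np


def select_detections(dets, thresh):
    """Hypothetical helper mirroring vis_detections' filtering, minus drawing."""
    inds = np.where(dets[:, -1] >= thresh)[0]   # rows whose score >= thresh
    if len(inds) == 0:
        return []
    results = []
    for i in inds:
        bbox = dets[i, :4]      # x1, y1, x2, y2
        score = dets[i, -1]
        results.append((bbox.tolist(), float(score)))
    return results


dets = np.array([[10, 20, 50, 80, 0.95],
                 [15, 25, 60, 90, 0.40],
                 [100, 120, 180, 200, 0.80]], dtype=np.float32)

kept = select_detections(dets, thresh=0.5)
for bbox, score in kept:
    print(bbox, round(score, 2))   # the 0.40 row is filtered out
```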
# -*- coding:utf-8 -*-
# Author: WUJiang
# np.where test

import numpy as np

a = np.array([
    [1, 2, 4, 5, 0.8],
    [2, 5, 7, 9, 0.9],
    [3, 6, 5, 20, 0.95]
])
# (array([1, 1, 2, 2], dtype=int64), array([2, 3, 1, 3], dtype=int64)): row and column indices of the 4 matching elements
# <class 'tuple'>
b = np.where(a > 5)
# [1 1 2 2]
# <class 'numpy.ndarray'>
b0 = b[0]
print(type(b0))
# [0.8  0.9  0.95]
print(a[:, 4])
# (array([1, 2], dtype=int64),)
# <class 'tuple'>
c = np.where(a[:, 4] > 0.8)
# [1 2]
# <class 'numpy.ndarray'>
print(c[0])
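Related ways to obtain the same indices, shown with made-up scores: the one-argument np.where(cond) is shorthand for np.nonzero(cond), np.flatnonzero returns the 1-D index array without the surrounding tuple, and boolean indexing selects the values directly. The three-argument form np.where(cond, x, y) instead chooses elementwise between x and y.

```python
import numpy as np

scores = np.array([0.8, 0.9, 0.95])
mask = scores > 0.8                       # [False  True  True]

idx_where = np.where(mask)[0]             # one-argument form, a tuple with one array
idx_nonzero = np.nonzero(mask)[0]         # what np.where(mask) is shorthand for
idx_flat = np.flatnonzero(mask)           # 1-D index array, no tuple

print(idx_where, idx_nonzero, idx_flat)   # [1 2] [1 2] [1 2]
print(scores[mask])                       # boolean indexing: the selected scores

# Three-argument form: pick 'keep' where mask is True, else 'drop'
labels = np.where(mask, 'keep', 'drop')
print(labels)                             # ['drop' 'keep' 'keep']
```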

