yolo源碼解析(三)


 

 

七 測試網絡

 模型測試包含於test.py文件,Detector類的image_detector()函數用於檢測目標。

復制代碼
import os
import cv2
import argparse
import numpy as np
import tensorflow as tf
import yolo.config as cfg
from yolo.yolo_net import YOLONet
from utils.timer import Timer

'''
用於測試
'''

class Detector(object):
復制代碼

1、類初始化函數

復制代碼
 def __init__(self, net, weight_file):
        '''
        構造函數
        利用 cfg 文件對網絡參數進行初始化,
        其中 offset 的作用應該是一個定長的偏移
        boundery1和boundery2 作用是在輸出中確定每種信息的長度(如類別,置信度等)。
        其中 boundery1 指的是對於所有的 cell 的類別的預測的張量維度,所以是 self.cell_size * self.cell_size * self.num_class
        boundery2 指的是在類別之后每個cell 所對應的 bounding boxes 的數量的總和,所以是self.boundary1 + self.cell_size * self.cell_size * self.boxes_per_cell
        
        
        args:
            net:YOLONet對象
            weight_file:檢查點文件路徑
        '''
        #yolo網絡
        self.net = net
        #檢查點文件路徑
        self.weights_file = weight_file
        #輸出文件夾路徑
        self.output_dir = os.path.dirname(self.weights_file)
         #VOC 2012數據集類別名
        self.classes = cfg.CLASSES
        # #VOC 2012數據類別數
        self.num_class = len(self.classes)
        ##圖像大小
        self.image_size = cfg.IMAGE_SIZE
        #單元格大小S
        self.cell_size = cfg.CELL_SIZE
        #每個網格邊界框的個數B=2
        self.boxes_per_cell = cfg.BOXES_PER_CELL
        #閾值參數
        self.threshold = cfg.THRESHOLD
        #IoU 閾值參數
        self.iou_threshold = cfg.IOU_THRESHOLD
        '''#將網絡輸出分離為類別和置信度以及邊界框的大小,輸出維度為7*7*20 + 7*7*2 + 7*7*2*4=1470'''
         #7*7*20
        self.boundary1 = self.cell_size * self.cell_size * self.num_class
         #7*7*20+7*7*2
        self.boundary2 = self.boundary1 +\
            self.cell_size * self.cell_size * self.boxes_per_cell

        #運行圖之前,初始化變量
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

        #恢復模型
        print('Restoring weights from: ' + self.weights_file)
        self.saver = tf.train.Saver()
        #直接載入最近保存的檢查點文件
        ckpt = tf.train.latest_checkpoint(self.output_dir)
        print("ckpt:",ckpt)         
        #如果存在檢查點文件 則恢復模型
        if ckpt!=None:
            #恢復最近的檢查點文件
            self.saver.restore(self.sess, ckpt) 
        else:
            #從指定檢查點文件恢復
            self.saver.restore(self.sess, self.weights_file)
復制代碼

2、draw_result()函數

在原始圖像上繪制邊界框,並添加一些附件信息,如目標類別,置信度。

復制代碼
    def draw_result(self, img, result):
        '''
        在原圖上繪制邊界框,以及附加信息
        
        args:
            img:原始圖片數據
            result:yolo網絡目標檢測到的邊界框,list類型 每一個元素對應一個目標框 
                  包含{類別名,x_center,y_center,w,h,置信度} 
        '''
        #遍歷每一個邊界框
        for i in range(len(result)):
            #x_center
            x = int(result[i][1])
            #y_center
            y = int(result[i][2])
            #w/2
            w = int(result[i][3] / 2)
            #h/2
            h = int(result[i][4] / 2)
            #繪制矩形框(目標邊界框) 矩形左上角,矩形右下角
            cv2.rectangle(img, (x - w, y - h), (x + w, y + h), (0, 255, 0), 2)            
            #繪制矩形框,用於存放類別名稱,使用灰度填充
            cv2.rectangle(img, (x - w, y - h - 20),
                          (x + w, y - h), (125, 125, 125), -1)
            #線型
            lineType = cv2.LINE_AA if cv2.__version__ > '3' else cv2.CV_AA
            #繪制文本信息 寫上類別名和置信度
            cv2.putText(
                img, result[i][0] + ' : %.2f' % result[i][5],
                (x - w + 5, y - h - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                (0, 0, 0), 1, lineType)
復制代碼

3、detect()函數

detect()函數用來對圖像進行目標檢測。

復制代碼
 def detect(self, img):
        '''
        圖片目標檢測
        
        args:
            img:原始圖片數據
            
        return:
            result:返回檢測到的邊界框,list類型 每一個元素對應一個目標框 
            包含{類別名,x_center,y_center,w,h,置信度}
        '''
        #獲取圖片的高和寬
        img_h, img_w, _ = img.shape
        #圖片縮放 [448,448,3]
        inputs = cv2.resize(img, (self.image_size, self.image_size))
        #BGR->RGB  uint->float32
        inputs = cv2.cvtColor(inputs, cv2.COLOR_BGR2RGB).astype(np.float32)
        #歸一化處理 [-1.0,1.0]
        inputs = (inputs / 255.0) * 2.0 - 1.0
        #reshape [1,448,448,3]
        inputs = np.reshape(inputs, (1, self.image_size, self.image_size, 3))

        #獲取網絡輸出第一項(即第一張圖片) [1,1470]
        result = self.detect_from_cvmat(inputs)[0]

        #對檢測的圖片的邊界框進行縮放處理,一張圖片可以有多個邊界框
        for i in range(len(result)):
            #x_center, y_center, w, h都是真實值,分別表示預測邊界框的中心坐標,寬和高,都是浮點型
            result[i][1] *= (1.0 * img_w / self.image_size)    #x_center
            result[i][2] *= (1.0 * img_h / self.image_size)    #y_center
            result[i][3] *= (1.0 * img_w / self.image_size)    #w
            result[i][4] *= (1.0 * img_h / self.image_size)    #h

        #<class 'list'> 6 ['person', 405.83171163286482, 161.40340532575334, 166.17623397282193, 298.85661533900668, 0.69636690616607666]
        #Average detecting time: 0.571s
        print(type(result),len(result),result[0])
        return result
復制代碼

4、detect_from_cvmat()函數

復制代碼
 def detect_from_cvmat(self, inputs):
        '''
        運行yolo網絡,開始檢測
        
        args:
            inputs:輸入數據  [None,448,448,3]
            
        return:
            results:返回目標檢測的結果,每一個元素對應一個測試圖片,每個元素包含着若干個邊界框
        
        '''
        #返回網絡最后一層,激活函數處理之前的值  形狀[None,1470]
        net_output = self.sess.run(self.net.logits,
                                   feed_dict={self.net.images: inputs})
        results = []
        
        #對網絡輸出每一行數據進行處理
        for i in range(net_output.shape[0]):
            results.append(self.interpret_output(net_output[i]))

        #返回處理后的結果
        return results
復制代碼

 5、interpret_output()函數

該函數對yolo網絡輸出的結果進行處理,提取出有目標的邊界框,方便后續的處理。

復制代碼
 def interpret_output(self, output):
        '''
        對yolo網絡輸出進行處理  
        
        args:
            output:yolo網絡輸出的每一行數據 大小為[1470,]
                    0:7*7*20:表示預測類別   
                    7*7*20:7*7*20 + 7*7*2:表示預測置信度,即預測的邊界框與實際邊界框之間的IOU
                    7*7*20 + 7*7*2:1470:預測邊界框    目標中心是相對於當前格子的,寬度和高度的開根號是相對當前整張圖像的(歸一化的)        
                    
        return:
             result:yolo網絡目標檢測到的邊界框,list類型 每一個元素對應一個目標框 
                  包含{類別名,x_center,y_center,w,h,置信度}   實際上這個置信度是yolo網絡輸出的置信度confidence和預測對應的類別概率的乘積
        '''
        #[7,7,2,20]
        probs = np.zeros((self.cell_size, self.cell_size,
                          self.boxes_per_cell, self.num_class))
        #類別概率 [7,7,20]
        class_probs = np.reshape(
            output[0:self.boundary1],
            (self.cell_size, self.cell_size, self.num_class))
        #置信度 [7,7,2]
        scales = np.reshape(
            output[self.boundary1:self.boundary2],
            (self.cell_size, self.cell_size, self.boxes_per_cell))
        #邊界框 [7,7,2,4]
        boxes = np.reshape(
            output[self.boundary2:],
            (self.cell_size, self.cell_size, self.boxes_per_cell, 4))
        #[14,7]  每一行[0,1,2,3,4,5,6]
        offset = np.array(
            [np.arange(self.cell_size)] * self.cell_size * self.boxes_per_cell)
        #[7,7,2] 每一行都是  [[0,0],[1,1],[2,2],[3,3],[4,4],[5,5],[6,6]]
        offset = np.transpose(
            np.reshape(
                offset,
                [self.boxes_per_cell, self.cell_size, self.cell_size]),
            (1, 2, 0))

        #目標中心是相對於整個圖片的
        boxes[:, :, :, 0] += offset
        boxes[:, :, :, 1] += np.transpose(offset, (1, 0, 2))
        boxes[:, :, :, :2] = 1.0 * boxes[:, :, :, 0:2] / self.cell_size
        #寬度、高度相對整個圖片的
        boxes[:, :, :, 2:] = np.square(boxes[:, :, :, 2:])

        #轉換成實際的編輯框(沒有歸一化的)
        boxes *= self.image_size

        #遍歷每一個邊界框的置信度
        for i in range(self.boxes_per_cell):
            #遍歷每一個類別
            for j in range(self.num_class):
                #在測試時,乘以條件類概率和單個盒子的置信度預測,這些分數編碼了j類出現在框i中的概率以及預測框擬合目標的程度。
                probs[:, :, i, j] = np.multiply(
                    class_probs[:, :, j], scales[:, :, i])

        #[7,7,2,20] 如果第i個邊界框檢測到類別j 則[;,;,i,j]=1
        filter_mat_probs = np.array(probs >= self.threshold, dtype='bool')
        #返回filter_mat_probs非0值的索引 返回4個List,每個list長度為n  即檢測到的邊界框的個數      
        filter_mat_boxes = np.nonzero(filter_mat_probs)
        #獲取檢測到目標的邊界框 [n,4]  n表示邊界框的個數
        boxes_filtered = boxes[filter_mat_boxes[0],
                               filter_mat_boxes[1], filter_mat_boxes[2]]        
        #獲取檢測到目標的邊界框的置信度 (n,)
        probs_filtered = probs[filter_mat_probs]  
        #獲取檢測到目標的邊界框對應的目標類別 (n,)
        classes_num_filtered = np.argmax(
            filter_mat_probs, axis=3)[
            filter_mat_boxes[0], filter_mat_boxes[1], filter_mat_boxes[2]]    
        #按置信度倒序排序,返回對應的索引
        argsort = np.array(np.argsort(probs_filtered))[::-1]
        boxes_filtered = boxes_filtered[argsort]
        probs_filtered = probs_filtered[argsort]
        classes_num_filtered = classes_num_filtered[argsort]

        for i in range(len(boxes_filtered)):
            if probs_filtered[i] == 0:
                continue
            for j in range(i + 1, len(boxes_filtered)):
                #計算n各邊界框,兩兩之間的IoU是否大於閾值,非極大值抑制                          
                if self.iou(boxes_filtered[i], boxes_filtered[j]) :
                    probs_filtered[j] = 0.0

        #非極大值抑制后的輸出
        filter_iou = np.array(probs_filtered > 0.0, dtype='bool')
        boxes_filtered = boxes_filtered[filter_iou]
        probs_filtered = probs_filtered[filter_iou]
        classes_num_filtered = classes_num_filtered[filter_iou]

        result = []
        #遍歷每一個邊界框
        for i in range(len(boxes_filtered)):
            result.append(
                [self.classes[classes_num_filtered[i]],  #類別名
                 boxes_filtered[i][0],                   #x中心
                 boxes_filtered[i][1],                   #y中心
                 boxes_filtered[i][2],                   #寬度
                 boxes_filtered[i][3],                   #高度
                 probs_filtered[i]])                     #置信度  

        return result
復制代碼

6、iou()函數

計算兩個邊界框的IoU值。

復制代碼
    def iou(self, box1, box2):
        '''
        計算兩個邊界框的IoU
        
        args:
            box1:邊界框1  [4,]   真實值
            box2:邊界框2  [4,]   真實值
        '''
        tb = min(box1[0] + 0.5 * box1[2], box2[0] + 0.5 * box2[2]) - \
            max(box1[0] - 0.5 * box1[2], box2[0] - 0.5 * box2[2])
        lr = min(box1[1] + 0.5 * box1[3], box2[1] + 0.5 * box2[3]) - \
            max(box1[1] - 0.5 * box1[3], box2[1] - 0.5 * box2[3])
        inter = 0 if tb < 0 or lr < 0 else tb * lr
        return inter / (box1[2] * box1[3] + box2[2] * box2[3] - inter)
復制代碼

7、camera_detector()函數

調用攝像頭實現實時目標檢測。

復制代碼
    def camera_detector(self, cap, wait=10):
        '''
        打開攝像頭,實時檢測
        
        '''
        #測試時間
        detect_timer = Timer()
        #讀取一幀
        ret, _ = cap.read()

        while ret:
            #讀取一幀
            ret, frame = cap.read()
            #測試其實時間
            detect_timer.tic()
            result = self.detect(frame)
            #測試結束時間
            detect_timer.toc()
            print('Average detecting time: {:.3f}s'.format(
                detect_timer.average_time))
            #繪制邊界框,以及添加附加信息
            self.draw_result(frame, result)
            #顯示
            cv2.imshow('Camera', frame)
            cv2.waitKey(wait)
復制代碼

8、image_detector()函數

對圖片進行目標檢測。

復制代碼
    def image_detector(self, imname, wait=0):
        '''
        目標檢測
        
        args:
            imname:測試圖片路徑
        '''
        #檢測時間
        detect_timer = Timer()
        #讀取圖片
        image = cv2.imread(imname)
        #image = cv2.resize(image,(int(image.shape[1]/2),int(image.shape[0]/2)))
        #檢測的起始時間
        detect_timer.tic()
        #開始檢測
        result = self.detect(image)
        #檢測的結束時間
        detect_timer.toc()
        print('Average detecting time: {:.3f}s'.format(
            detect_timer.average_time))
        #繪制檢測結果
        self.draw_result(image, result)
        cv2.imshow('Image', image)
        cv2.waitKey(wait)
復制代碼

介紹完了Detector這個類,我們來看一下main函數。該函數比較檢測,首先解析命令行參數,然后創建yolo網絡,以及檢測器對象,最后調用image_detector()函數對圖片進行目標檢測。

復制代碼
def main():
    #創建一個解析器對象,並告訴它將會有些什么參數。當程序運行時,該解析器就可以用於處理命令行參數。
    #https://www.cnblogs.com/lovemyspring/p/3214598.html
    parser = argparse.ArgumentParser()
    #定義參數
    parser.add_argument('--weights', default="YOLO_small.ckpt", type=str)
    parser.add_argument('--weight_dir', default='weights', type=str)
    parser.add_argument('--data_dir', default="data", type=str)
    parser.add_argument('--gpu', default='', type=str)
    #定義了所有參數之后,你就可以給 parse_args() 傳遞一組參數字符串來解析命令行。默認情況下,參數是從 sys.argv[1:] 中獲取
    #parse_args() 的返回值是一個命名空間,包含傳遞給命令的參數。該對象將參數保存其屬性
    args = parser.parse_args()

    #設置環境變量
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    #創建YOLO網絡對象
    yolo = YOLONet(False)
    #加載檢查點文件
    weight_file = os.path.join(args.data_dir, args.weight_dir, args.weights)
    weight_file = './data/pascal_voc/weights/YOLO_small.ckpt' 
    #weight_file = './data/pascal_voc/output/2018_07_09_17_00/yolo.ckpt-1000'
    
    #創建測試對象
    detector = Detector(yolo, weight_file)

    # detect from camera
    # cap = cv2.VideoCapture(-1)
    # detector.camera_detector(cap)

    # detect from image file
    imname = 'test/car.jpg'
    detector.image_detector(imname)
復制代碼

我們執行如下代碼,開始測試網絡:

if __name__ == '__main__':
    tf.reset_default_graph()
    main()

 

我們可以看到yolo網絡對小目標檢測效果並不好,漏檢了一個目標。這主要與yolo的網絡結構以及損失函數有關。除此之外yolo網絡還有一些其他缺點,我們總結如下:

  • 漏檢。每個網格只預測一個類別的邊界框,而且最后只取置信度最大的那個邊界框。這就導致如果多個不同物體(或者同類物體的不同實體)的中心落在同一個網格中,會造成漏檢。yolo對相互靠的很近的物體,還有很小的群體檢測效果不好,這是因為一個網格中只預測了兩個框,並且只屬於一類。
  • 位置精准性差。召回率低。由於損失函數的問題,定位誤差是影響檢測效果的主要原因。尤其是大小物體的處理上,還有待加強。
  • 對測試圖像中,同一類物體出現的新的不常見的長寬比和其他情況是。泛化能力偏弱。

參考文章:

[1]argparse - 命令行選項與參數解析(轉)

[2]Yolo v1詳解及相關問題解答


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM