Faster rcnn代碼理解(2)


接着上篇的博客,咱們繼續看一下Faster RCNN的代碼~

上次大致講完了Faster rcnn在訓練時是如何獲取imdb和roidb文件的,主要都在train_rpn()的get_roidb()函數中,train_rpn()函數后面的部分基本沒什么需要講的了,那我們再回到訓練流程中來:

這一步訓練的網絡結構見下圖:

訓練的第一步就這么完成了(RPN網絡使用gt_roidb訓練完成),還有,這里的train_rpn()函數中有涉及到train_net()函數,即用來訓練得到網絡模型,我會在訓練fast rcnn時給出講解;之后再進行第二步:

產生proposal的網絡結構如下:

這一步是利用上一步訓練好的RPN網絡來生成proposals供下一步中的fast rcnn訓練,這里補充一下,字典mp_kwargs中的參數solver,是使用get_solvers()函數得到的,見下:

由上圖可知,四步訓練時的網絡結構都在相應的四個solver文件有解釋。好了回到訓練流程中,產生proposal的函數是rpn_generate()函數,那我們進入這個函數:

首先設置了什么預NMS(不太懂),還有就是經過NMS后產生2000個proposals,然后初始化caffe,再用get_imdb()函數得到imdb數據,方法和前面一樣,再用caffe.NET()加載RPN網絡,再使用imdb_proposals()得到proposal,那我們就進入這個函數:

該函數的作用就是在所有的圖片上生成proposal,不過作者又嵌套了一個im_proposals()函數,即在一張圖片上產生proposals(這個嵌套看似多余,但是為后面再添一些測試腳本提供了方便,我猜rgb大神可能有這個目的),進入im_proposals()函數中:

見上圖,首先用_get_image_blob()函數將圖片數據轉換為caffe的blob格式,進入該函數:

最終得到的blob格式為(batch elem , channel , height , width),im_info格式為[M,N,im_scale],其中im_scale是縮放比例,原始圖片輸入faster rcnn中進行訓練時都需要先縮放成統一的規格;再回到im_proposals()函數中,使用net.forward()函數進行一次前向傳播,獲得proposals,ok。之后再回到imdb_proposals()函數中,最后返回得到的imdb_boxes,即我們從RPN上產生的proposals。再回到rpn_generate()函數中,接着就是將生成的proposals保存並傳輸到多線程中去供下一步訓練使用,這個函數使命就暫時完成了;

再到第三步,訓練fast rcnn網絡,見下:

這一步訓練的網絡結構如下圖:

注意這時候訓練fast rcnn的roi就是前面訓練好的RPN網絡生成的proposals了(訓練RPN網絡用的是gt框),相關配置都在mp_kwargs字典中,很明顯,我們要進入train_fast_rcnn()中一探究竟:

這里首先設置了訓練使用的rpn_roidb方法(RPN用的是gt_roidb方法),由於這時候的cfg.TRAIN.PROPOSAL_METHOD變成了rpn_roidb,所以相應的使用的get_roidb()也相應地改變,此時使用rpn_roidb()方法,進入該函數:

該方法首先先獲得gt_roidb,然后再用_load_rpn_roidb()獲得由RPN產生的roidb,進入該函數:

接着進入create_roidb_from_box_list()函數中:

這個函數凸顯了數據結構的重要性,我們需要重點關注一下這其中'gt_overlaps'、'argmaxes'、'maxes'、'overlaps'、'I'的結構,對於理解這個函數費仲重要,最好在紙上寫出來~將RPN產生的proposal制作成roidb后,再回到rpn_roidb()中,使用merge_roidbs()將gt_roidb和rpn_roidb進行組合,見該函數:

這樣子就得到了最終訓練fast rcnn所需要的roidb數據,ok~再回到train_fast_rcnn()中,接着我們就來看看train_net()這個函數,進入該函數:

首先是使用filter_roidb()對之前產生的用於訓練fast_rcnn的roidb再進行一次篩選,具體過程參見該篩選函數:

篩選過之后,再回到train_net()函數中,創建一個solverWrapper對象,其中就是訓練得到的網絡模型,進入這個類中:

上圖是它的類定義中的一部分,我們先來看看它的初始化函數,這里需要注意的是add_bbox_regression_targets()這個函數,它的作用是為RPN產生的proposal提供回歸屬性,該函數向roidb中再添加一個key:'bbox_targets',它的格式如:targets[][5]:第一個元素是label,后面四個元素就是論文中談及的tx,ty,tw,th;好的,我們進入這個函數:

上圖是改函數的前半部分,主要看_compute_targets()函數,它產生了回歸屬性,進入該函數:

產生了回歸屬性,OK,再回到add_bbox_regression_targets()函數中,看后面剩下的部分:

這部分主要得到rpn_roidb的坐標的均值和方差,可以用來進行坐標歸一化;OK,再回到SolverWrapper類中,剩下的則是snapshot快照方法,和train_model方法,回到train_net()函數中,接着再調用train_model()方法,進入該函數:

上圖的函數就是使用SGD得到訓練模型也就是我們需要的fast_rcnn網絡模型,好了,train_net()函數就介紹到這兒了~

再回到train_fast_rcnn()函數中,剩下的都是保存之類的,那我們再回到訓練流程中,剩下的幾步訓練流程就如法炮制了,見下:

這兩步訓練的網絡結構見下圖:

這一步訓練的網絡結構見下圖:

這樣子,就通過分步訓練得到了最終的網絡模型,最后就是一些收尾工作了:

好了,終於全部弄完,接下來我們就來看作者在網絡結構中添加的那幾層了~ 

(轉載請注明出處)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM