inds_inside = np.where( (all_anchors[:, 0] >= -self._allowed_border) & (all_anchors[:, 1] >= -self._allowed_border) & (all_anchors[:, 2] < im_info[1] + self._allowed_border) & # width (all_anchors[:, 3] < im_info[0] + self._allowed_border) # height )[0] # keep only inside anchors anchors = all_anchors[inds_inside, :]
這部分代碼是把所有anchor中超過了圖片邊界部分的anchor去掉,即論文中說的cross-boundary anchors
# fg label: for each gt, anchor with highest overlap labels[gt_argmax_overlaps] = 1 # fg label: above threshold IOU labels[max_overlaps >= cfg.TRAIN.RPN_POSITIVE_OVERLAP] = 1
這部分代碼是把和gt-roi有最大iou的anchor和與任何gt-roi iou大於0.7的anchor的label置為1,即前景。這和論文中所說的是一樣的。
if cfg.TRAIN.RPN_CLOBBER_POSITIVES: # assign bg labels last so that negative labels can clobber positives labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0
把和所有gt-roi iou都小於0.3的achor的label置為0
# label: 1 is positive, 0 is negative, -1 is dont care labels = np.empty((len(inds_inside), ), dtype=np.float32) labels.fill(-1)
這是label的初始化的代碼,所有的label都置為-1
所以總的來看,label分為3類,一類是0,即背景label;一類是1,即前景label;另一類既不是前景也不是背景,置為-1。論文中說只有前景和背景對訓練目標有用,這種-1的label對訓練沒用。
# subsample positive labels if we have too many num_fg = int(cfg.TRAIN.RPN_FG_FRACTION * cfg.TRAIN.RPN_BATCHSIZE) fg_inds = np.where(labels == 1)[0] if len(fg_inds) > num_fg: #從所有label為1的anchor中選擇128個,剩下的anchor的label全部置為-1 disable_inds = npr.choice( fg_inds, size=(len(fg_inds) - num_fg), replace=False) labels[disable_inds] = -1 # subsample negative labels if we have too many num_bg = cfg.TRAIN.RPN_BATCHSIZE - np.sum(labels == 1)#這里num_bg不是直接設為128,而是256減去label為1的個數,這樣如果label為1的不夠,就用label為0的填充,這個代碼實現很巧 bg_inds = np.where(labels == 0)[0] if len(bg_inds) > num_bg: #將沒被選擇作為訓練的anchor的label置為-1 disable_inds = npr.choice( bg_inds, size=(len(bg_inds) - num_bg), replace=False) labels[disable_inds] = -1 #print "was %s inds, disabling %s, now %s inds" % ( #len(bg_inds), len(disable_inds), np.sum(labels == 0))
論文中說從所有anchor中隨機選取256個anchor,前景128個,背景128個。注意:那種label為-1的不會當前景也不會當背景。
這兩段代碼是前一部分是在所有前景的anchor中選128個,后一部分是在所有的背景achor中選128個。如果前景的個數少於了128個,就把所有的anchor選出來,差的由背景部分補。這和fast rcnn選取roi一樣。
這是論文中rpn的loss函數:
這個loss函數和fast rcnn中的loss函數差不多,所以在計算的時候是每個坐標單獨進行smoothL1計算,所以參數Pi*和Nreg必須弄成4維的向量,並不是在論文中的就一個數值
bbox_inside_weights = np.zeros((len(inds_inside), 4), dtype=np.float32) bbox_inside_weights[labels == 1, :] = np.array(cfg.TRAIN.RPN_BBOX_INSIDE_WEIGHTS) bbox_outside_weights = np.zeros((len(inds_inside), 4), dtype=np.float32) if cfg.TRAIN.RPN_POSITIVE_WEIGHT < 0: # uniform weighting of examples (given non-uniform sampling) num_examples = np.sum(labels >= 0) positive_weights = np.ones((1, 4)) * 1.0 / num_examples negative_weights = np.ones((1, 4)) * 1.0 / num_examples else: assert ((cfg.TRAIN.RPN_POSITIVE_WEIGHT > 0) & (cfg.TRAIN.RPN_POSITIVE_WEIGHT < 1)) positive_weights = (cfg.TRAIN.RPN_POSITIVE_WEIGHT / np.sum(labels == 1)) negative_weights = ((1.0 - cfg.TRAIN.RPN_POSITIVE_WEIGHT) / np.sum(labels == 0)) bbox_outside_weights[labels == 1, :] = positive_weights bbox_outside_weights[labels == 0, :] = negative_weights
bbox_inside_weights實際上指的就是Pi*,bbox_outside_weights指的是Nreg。
論文中說如果anchor是前景,Pi*就是1,為背景,Pi*就是0。label為-1的,在這個代碼來看也是設置為0,應該是在后面不會參與計算,這個設置為多少都無所謂。
Nreg是進行標准化操作,就是取平均。這個平均是把所有的label 0和label 1加起來。因為選的是256個anchor做訓練,所以實際上這個值是1/256。
值得注意的是,rpn網絡的訓練是256個anchor,128個positive,128個negative。但anchor_target_layer層的輸出並不是只有256個anchor的label和坐標變換,而是所有的anchor。其中_unmap函數就很好體現了這一點。那訓練的時候怎么實現訓練這256個呢?實際上,這一層的4個輸出,rpn_labels是需要輸出到rpn_loss_cls層,其他的3個輸出到rpn_loss_bbox,label實際上就是loss function前半部分中的Pi*(即計算分類的loss),這是一個log loss,為-1的label是無法進行log計算的,剩下的0、1就直接計算,這一部分實現了256。loss function后半部分是計算bbox坐標的loss,Pi*,也就是bbox_inside_weights,論文中說了activated only for positive anchors,只有為正例的anchor才去計算坐標的損失,這是Pi*是1,其他情況都是0
bbox_inside_weights = np.zeros((len(inds_inside), 4), dtype=np.float32)
bbox_inside_weights[labels == 1, :] = np.array(cfg.TRAIN.RPN_BBOX_INSIDE_WEIGHTS)
這段代碼也體現了這個思想,所以這也實現了256。
可以這樣去理解:anchor_target_layer輸出的是所有anchor的label,bbox_targets。但真正進行了loss計算的只有那256個anchor。可以看下面這個loss函數,i是anchor的下標,這個i計算是計算了所有的anchor的,但只有那256個才真正改變了loss值,其他的都是0。
_unmap函數:因為all_anchors裁減掉了2/3左右,僅僅保留在圖像內的anchor。這里就是將其復原作為下一層的輸入了,並reshape成相應的格式。