之前,對SSD的論文進行了解讀,可以回顧之前的博客:https://www.cnblogs.com/dengshunge/p/11665929.html。
為了加深對SSD的理解,因此對SSD的源碼進行了復現,主要參考的github項目是ssd.pytorch。同時,我自己對該項目增加了大量注釋:https://github.com/Dengshunge/mySSD_pytorch
搭建SSD的項目,可以分成以下三個部分:
接下來,本篇博客重點分析損失函數的構建。
檢測任務的損失函數,與分類任務的損失函數具有很大不同。在檢測的損失函數中,不僅需要計類別置信度的差異,坐標的差異,還需要使用到各種tricks,例如hard negative mining等。
在train.py中,首先需要對損失函數MultiBoxLoss()進行初始化,需要傳入的參數為num_classes類別數,正例的IOU閾值和hard negative mining的正負樣本比例。在論文中,VOC的類別總數是21(20個類別加上1個背景);當預測框與GT框的IOU大於0.5時,認為該預測框是正例;hard negative mining的正樣本和負樣本的比例是1:3。
# 損失函數 criterion = MultiBoxLoss(num_classes=voc['num_classes'], overlap_thresh=0.5, neg_pos=3)
在models/multibox_loss中,定義了損失函數MultiBoxLoss()。在函數forward()中,需要傳進來兩個參數,分別是predictions和targets,其中,predictions是SSD網絡得到的結果,分別是預測框坐標,類別置信度和先驗錨點框;而targets是則是數據讀取中的值,是GT框的坐標和類別label。首先,需要創建坐標loc_t和類別置信度conf_t的tensor,其shape分別是[batch_size,8732,4]和[batch_size,8732]。然后,使用一個for循環,將GT框與先驗錨點框的坐標與label進行match,得到每個錨點框的label和坐標偏差,並將結果保存與loc_t和conf_t中。由於制定了某些錨點框用於預測目標,因此,接下來,需要使用這部分錨點框信息來計算損失。取出含目標的錨點框,得到其index,其中,pos的shape為[batch_size,8732],每個元素是true或者false。再從網絡預測的8732個預測框中,取出同樣index的預測框的坐標偏差loc_p,而loc_t則是同樣index的先驗錨點框的坐標偏差。由於錨點框對應上了,則使用smooth_l1來計算預測框回歸的算是loss_l,如下圖所示的$L_{loc}$,圖片來源。
接下來,則是使用hard negative mining和計算置信度損失。首先為模型預測出來的置信度conf_data進行維度變換,由[batch_size,8732,21]變成[batch_size*8732,21]的batch_conf,應該是為了方便下面進行計算。接下來,計算所有預測框的置信度損失loss_c,將含目標的錨點框(正例)的損失置0,並對損失進行排名,從而選出損失最大的前num_neg個損失的index。將正例的pos_index和損失最大的負例neg_idx提取出來成conf_p,用於參與訓練中,與相同index的先驗錨點框進行計算交叉熵損失計算。最后將置信度損失和位置損失返回。
class MultiBoxLoss(nn.Module): ''' SSD損失函數的計算 ''' def __init__(self, num_classes, overlap_thresh, neg_pos): super(MultiBoxLoss, self).__init__() self.num_classes = num_classes # 類別數 self.threshold = overlap_thresh # GT框與先驗錨點框的閾值 self.negpos_ratio = neg_pos # 負例的比例 def forward(self, predictions, targets): ''' 對損失函數進行計算: 1.進行GT框與先驗錨點框的匹配,得到loc_t和conf_t,分別表示錨點框需要匹配的坐標和錨點框需要匹配的label 2.對包含目標的先驗錨點框loc_t(即正例)與預測的loc_data計算位置損失函數 3.對負例(即背景)進行損失計算,選擇損失最大的num_neg個負例和正例共同組成訓練樣本,取出這些訓練樣本的錨點框targets_weighted 與置信度預測值conf_p,計算置信度損失: a)為Hard Negative Mining計算最大置信度loss_c b)將loss_c中正例對應的值置0,即保留了所有負例 c)對此loss_c進行排序,得到損失最大的idx_rank d)計算用於訓練的負例的個數num_neg,約為正例的3倍 e)選擇idx_rank中前num_neg個用作訓練 f)將正例的index和負例的index共同組成用於計算損失的index,並從預測置信度conf_data和真實置信度conf_t提出這些樣本,形成 conf_p和targets_weighted,計算兩者的置信度損失. :param predictions: 一個元祖,包含位置預測,置信度預測,先驗錨點框 位置預測:(batch_size,num_priors,4),即[batch_size,8732,4] 置信度預測:(batch_size,num_priors,num_classes),即[batch_size, 8732, 21] 先驗錨點框:(num_priors,4),即[8732, 4] :param targets: 真實框的坐標與label,[batch_size,num_objs,5] 其中,5代表[xmin,ymin,xmia,ymax,label] ''' loc_data, conf_data, priors = predictions num = loc_data.shape[0] # 即batch_size大小 priors = priors[:loc_data.shape[1], :] # 取出8732個錨點框,與位置預測的錨點框數量相同 num_priors = priors.shape[0] # 8732 loc_t = torch.Tensor(num, num_priors, 4) # [batch_size,8732,4],生成隨機tensor,后續用於填充 conf_t = torch.Tensor(num, num_priors) # [batch_size,8732] # 取消梯度更新,貌似默認是False loc_t.requires_grad = False conf_t.requires_grad = False for idx in range(num): truths = targets[idx][:, :-1] # 坐標值,[xmin,ymin,xmia,ymax] labels = targets[idx][:, -1] # label defaults = priors.cuda() match(self.threshold, truths, defaults, labels, loc_t, conf_t, idx) if torch.cuda.is_available(): loc_t = loc_t.cuda() conf_t = conf_t.cuda() # shape:[batch_size,8732],其元素組成是類別標簽號和背景 pos = conf_t > 0 # 排除label=0,即排除背景,shape[batch_size,8732],其元素組成是true或者false # Localization Loss (Smooth L1),定位損失函數 # Shape: [batch,num_priors,4] # pos.dim()表示pos有多少維,應該是一個定值(2) # pos由[batch_size,8732]變成[batch_size,8732,1],然后展開成[batch_size,8732,4] pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) loc_p = loc_data[pos_idx].view(-1, 4) # [num_pos,4],取出帶目標的這些框 loc_t = loc_t[pos_idx].view(-1, 4) # [num_pos,4] # 位置損失函數 loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum') # 這里對損失值是相加,有公式可知,還沒到相除的地步 # 為Hard Negative Mining計算max conf across batch batch_conf = conf_data.view(-1, self.num_classes) # shape[batch_size*8732,21] # gather函數的作用是沿着定軸dim(1),按照Index(conf_t.view(-1, 1))取出元素 # batch_conf.gather(1, conf_t.view(-1, 1))的shape[8732,1],作用是得到每個錨點框在匹配GT框后的label loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1).long()) # 這個不是最終的置信度損失函數 # Hard Negative Mining # 由於正例與負例的數據不均衡,因此不是所有負例都用於訓練 loss_c[pos.view(-1, 1)] = 0 # pos與loss_c維度不一樣,所以需要轉換一下,選出負例 loss_c = loss_c.view(num, -1) # [batch_size,8732] _, loss_idx = loss_c.sort(1, descending=True) # 得到降序排列的index _, idx_rank = loss_idx.sort(1) num_pos = pos.sum(1, keepdim=True) # pos里面是true或者false,因此sum后的結果應該是包含的目標數量 num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1) # 生成一個隨機數用於表示負例的數量,正例和負例的比例約3:1 neg = idx_rank < num_neg.expand_as(idx_rank) # [batch_size,8732] 選擇num_neg個負例,其元素組成是true或者false # 置信度損失,包括正例和負例 # [batch_size, 8732, 21],元素組成是true或者false,但true代表着存在目標,其對應的index為label pos_idx = pos.unsqueeze(2).expand_as(conf_data) neg_idx = neg.unsqueeze(2).expand_as(conf_data) # pos_idx由true和false組成,表示選擇出來的正例,neg_idx同理 # (pos_idx + neg_idx)表示選擇出來用於訓練的樣例,包含正例和反例 # torch.gt(other)函數的作用是逐個元素與other進行大小比較,大於則為true,否則為false # 因此conf_data[(pos_idx + neg_idx).gt(0)]得到了所有用於訓練的樣例 conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes) targets_weighted = conf_t[(pos + neg).gt(0)] loss_c = F.cross_entropy(conf_p, targets_weighted.long(), reduction='sum') # L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N N = num_pos.sum() # 一個batch里面所有正例的數量 loss_l /= N loss_c /= N return loss_l, loss_c
在hard negative mining中,需要先計算loss_c。從代碼可以看到 loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1).long()) ,這句代碼就是置信度損失的計算,可以參考公式進行理解。這里可以提及一下,對loss_c的兩次排序,參考這篇博客,首先對值進行降序排序,得到排名1,然后對排名又進行降序排序,得到排名2,如下圖所示,即能取出idx_rank的前N個,可獲得損失最大那些值,即變量neg的作用。
在計算損失函數時,提及了函數match(),這個函數位於models/box_utils.py中,是一個非常關鍵的函數,對應論文的匹配策略那一章節,其作用是為每個錨點框指定GT框和為每個GT框指定錨點框。需要傳進來幾個參數,truths是GT框的坐標,priors是先驗錨點框的坐標[中心點x,中心點y,W,H],labels是GT框對應的類別(不包含背景),loc_t和conf_t是用來保存結果的,idx是第i張圖片。
為了方便表述,num_objects表示一張圖中,GT框的數量;num_priors表示先驗錨點框的數量,即8732。
第一步,由於先驗錨點框priors的坐標形式是[中心點x,中心點y,W,H],需要使用函數point_from()來將其轉化成[x_min,y_min,x_max,y_max]。然后計算每個GT框與所有先驗錨點框的jaccard值,即IOU的值,使用了numpy風格的計算方式,返回的變量overlaps的shape為[GT框數量,8732]。
第二步,根據論文,為每個GT框匹配一個最大IOU的先驗錨點框,確保每個GT框至少有一個錨點框進行預測。
第三步,為每個錨點框匹配上一個最大IOU的GT框來進行預測。
第四步,變量best_truth_overlap保存着每個框與GT框的最大IOU值(第三步的結果),使用index_fill()函數,將第二步的結果同步到這個變量中。在index_fill()函數中,使用數值2來進行填充,是為了確保第二步中得到的錨點框肯定會被選到。對變量best_truth_idx也進行同樣的處理。
第五步,由於傳入進來的labels的類別是從0開始的,SSD中認為0應該是背景,所以,需要對labels進行加一。這里需要注意一下,best_truth_idx的shape是[8732],每個元素的范圍為[0,num_objects],所以conf的shape為[num_priors],每個元素表示先驗錨點框的label(0是背景)。同時,需要將變量best_truth_overlap中IOU小於閾值(0.5)的錨點框的label設置為0。並將結果保存與conf_t,返回給外面的函數用於計算。
第六步,同樣需要將GT框的坐標進行擴展,形成shape為[num_priors,4]的matches,這樣每個錨點框都有對應的坐標進行預測,但最終並不是每個錨點框都用於訓練中。
第七步,使用GT框與錨點框進行編碼,對應論文中的公式2,得到shape為[num_priors,4]的值,即偏差,將此結果返回出去。
注意,這里使用的是GT框的信息和先驗錨點框的信息,並沒有涉及到網絡預測出來的結果。得到每個錨點框的類別conf_t和坐標loc_t。由於沒有用到網絡預測的結果,可以認為這部分一直都是定值。
def match(threshold, truths, priors, labels, loc_t, conf_t, idx): ''' 這個函數對應論文中的matching strategy匹配策略.SSD需要為每一個先驗錨點框都指定一個label, 這個label或者指向背景,或者指向每個類別. 論文中的匹配策略是: 1.首先,每個GT框選擇與其IOU最大的一個錨點框,並令這個錨點框的label等於這個GT框的label 2.然后,當錨點框與GT框的IOU大於閾值(0.5)時,同樣令這個錨點框的label等於這個GT框的label 因此,代碼上的邏輯為: 1.計算每個GT框與每個錨點框的IOU,得到一個shape為[num_object,num_priors]的矩陣overlaps 2.選擇與GT框的IOU最大的錨點框,錨點框的index為best_prior_idx,對應的IOU值為best_prior_overlap 3.為每一個錨點框選擇一個IOU最大的GT框,可能會出現多個錨點框匹配一個GT框的情況,此時,每個錨點框對應GT框的index為best_truth_idx, 對應的IOU為best_truth_overlap.注意,此時IOU值可能會存在小於閾值的情況. 4.第3步可能到導致存在GT框沒有與錨點框匹配上的情況,所以要和第2步進行結合.在第3步的基礎上,對best_truth_overlap進行選擇,選擇出 best_prior_idx這些錨點框,讓其對其的IOU等於一個大於1的定值;並且讓best_truth_idx中index為best_prior_idx的錨點框的label 與GT框對應上.最終,best_truth_overlap表示每個錨點框與GT框的最大IOU值,而best_truth_idx表示每個錨點框用於與相應的GT框進行 匹配. 5.第4步中,會存在IOU小於閾值的情況,要將這些小於IOU閾值的錨點框的label指向背景,完成第二條匹配策略. labels表示GT框對應的標簽號,"conf=labels[best_truth_idx]+1"得到每個錨點框對應的標簽號,其中label=0是背景. "conf[best_truth_overlap < threshold] = 0"則將小於IOU閾值的錨點框的label指向背景 6.得到的conf表示每個錨點框對應的label,還需要一個矩陣,來表示每個錨點框需要匹配GT框的坐標. truths表示GT框的坐標,"matches = truths[best_truth_idx]"得到每個錨點框需要匹配GT框的坐標. :param threshold:IOU的閾值 :param truths:GT框的坐標,shape:[num_obj,4] :param priors:先驗錨點框的坐標,shape:[num_priors,4],num_priors=8732 :param labels:這些GT框對應的label,shape:[num_obj],此時label=0還不是背景 :param loc_t:坐標結果會保存在這個tensor :param conf_t:置信度結果會保存在這個tensor :param idx:結果保存的idx ''' # 第1步,計算IOU overlaps = jaccard(truths, point_from(priors)) # shape:[num_object,num_priors] # 第2步,為每個真實框匹配一個IOU最大的錨點框,GT框->錨點框 # best_prior_overlap為每個真實框的最大IOU值,shape[num_objects,1] # best_prior_idx為對應的最大IOU的先驗錨點框的Index,其元素值的范圍為[0,num_priors] best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) # 第3步,若先驗錨點框與GT框的IOU>閾值,也將這些錨點框匹配上,錨點框->GT框 # best_truth_overlap為每個先驗錨點框對應其中一個真實框的最大IOU,shape[1,num_priors] # best_truth_idx為每個先驗錨點框對應的真實框的index,其元素值的范圍為[0,num_objects] best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) best_prior_idx.squeeze_(1) # [num_objects] best_prior_overlap.squeeze_(1) # [num_objects] best_truth_idx.squeeze_(0) # [num_priors],8732 best_truth_overlap.squeeze_(0) # [num_priors],8732 # 第4步 # index_fill_(self, dim: _int, index: Tensor, value: Number)對第dim行的index使用value進行填充 # best_truth_overlap為第一步匹配的結果,需要使用到,使用best_prior_idx是第二步的結果,也是需要使用上的 # 所以在best_truth_overlap上進行填充,表明選出來的正例 # 使用2進行填充,是因為,IOU值的范圍是[0,1],只要使用大於1的值填充,就表明肯定能被選出來 best_truth_overlap.index_fill_(0, best_prior_idx, 2) # 確定最佳先驗錨點框 # 確保每個GT框都能匹配上最大IOU的先驗錨點框 # 得到每個先驗錨點框都能有一個匹配上的數字 # best_prior_idx的元素值的范圍是[0,num_priors],長度為num_objects for j in range(best_prior_idx.size(0)): best_truth_idx[best_prior_idx[j]] = j # 第5步 conf = labels[best_truth_idx] + 1 # Shape: [num_priors],0為背景,所以其余編號+1 conf[best_truth_overlap < threshold] = 0 # 置信度小於閾值的label設置為0 # 第6步 matches = truths[best_truth_idx] # 取出最佳匹配的GT框,Shape: [num_priors,4] # 進行位置編碼 loc = encode(matches, priors,voc['variance']) loc_t[idx] = loc # [num_priors,4],應該學習的編碼偏差 conf_t[idx] = conf # [num_priors],每個錨點框的label
在函數match()中,使用到了函數encode()來對位置進行編碼。參考博客和R-CNN中的公式,假設先驗錨點框的坐標為$(d^{cx},d^{cy},d^w,d^h)$,預測框的坐標為$(b^{cx},b^{cy},b^w,b^h)$,則預測框的轉換值l為:
$$l^{cx}=(b^{cx}-d^{cx})/d^w, l^{cy}=(b^{cy}-d^{cy})/d^h$$
$$b^w=d^wexp(l^x), b^h=d^hexp(l^h)$$
而代碼中,我們利用了方差的信息,因此進行了相應的調整,整體上是一致的。
def encode(matched, priors, variances): ''' 對坐標進行編碼,對應論文中的公式2 利用GT框和先驗錨點框,計算偏差,用於回歸 :param matched: 每個先驗錨點框對應最佳的GT框,Shape: [num_priors, 4], 其中4代表[xmin,ymin,xmax,ymax] :param priors: 先驗錨點框,Shape: [num_priors,4], 其中4代表[中心點x,中心點y,寬,高] :return: shape:[num_priors, 4] ''' g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] # 計算GT框與錨點框中心點的距離 g_cxcy /= (variances[0] * priors[:, 2:]) g_wh = (matched[:, 2:] - matched[:, :2]) # xmax-xmin,ymax-ymin g_wh /= priors[:, 2:] g_wh = torch.log(g_wh) / variances[1] return torch.cat([g_cxcy, g_wh], 1)
至此,SSD的損失函數構建以介紹完成。相比於分類任務,目標檢測的損失函數構建需要更多的代碼,包含了各種tricks。