論文鏈接:https://arxiv.org/abs/1706.07567
開源代碼:https://github.com/chaoyuaw/incubator-mxnet/tree/master/example/gluon/embedding_learning
Sampling Matters in Deep Embedding Learning
Abstract
深度嵌入(embeddings)回答了一個簡單的問題:兩幅圖像有多相似?學習這些嵌入是驗證、zero-shot學習和可視化搜索的基礎。最突出的優化深度卷積網絡的方法是使用合適的損失函數,如contrastive損失或triplet損失。雖然大量的工作只集中在損失函數上,但我們在本文中表明,選擇訓練樣本的方法也起到同樣重要的作用。我們提出了距離加權抽樣(distance weightedsampling)方法,它選擇了比傳統方法有着更多的信息和更穩定的例子。此外,我們證明了一個簡單的基於邊際損失足以勝過所有其他損失函數。我們在Stanford Online Products,CAR196,和CUB200-2011數據集上評估我們的方法在圖像檢索和聚類的效果,以及在LFW數據集上評估人臉驗證的效果。我們的方法在所有這些方面都達到了最先進的性能。
1. Introduction
將圖像轉換成豐富的語義表征的模型是現代計算機視覺的核心,其應用范圍從zero-learning學習[5,41]和視覺搜索[3,10,22,29],到人臉識別[6,23,25,28]或細粒度檢索[22,28,29]。經過訓練以尊重pairwise關系的深度網絡已成為最成功的嵌入模型[4,6,25,34]。
深度嵌入學習的核心思想很簡單:將相似的圖像在嵌入空間中拉近,將不同的圖像分開。例如,contrastive損失[11]迫使所有的positive圖像接近,而所有的negatives應該被分開一定的固定距離。然而,對所有圖像使用相同的固定距離是非常有限制的,阻止了嵌入空間中的任何扭曲。這就導致了triplet損失的出現,其只需要在每個例子[25]的基礎上,使得negative圖像比positive圖像更遠即可。triplet損失是目前標准嵌入任務中表現最好的損失之一[22,25,45]。與pairwise損失不同的是,triplet損失不僅改變了孤立的損失函數,還改變了正樣本和負樣本的選擇方式。這就給我們提供了兩個問題:損失和采樣策略的選擇。如圖1所示。
(這篇論文給的采樣方法為Distance weighted sampling, 損失為margin based loss)
在本文中,我們證明了樣本選擇在嵌入學習中所起的作用與樣本損失的作用相同或更重要。例如,對於相同的損失函數,不同的采樣策略會導致截然不同的解決方案。同時,在良好的采樣策略下,許多不同的損失函數的表現相似:如果使用相同的采樣策略,contrastive損失和triplet損失的效果幾乎一樣好。在本文中,我們分析了現有的抽樣策略,並說明了它們為什么有效和為什么無效。然后,我們提出了一種新的抽樣策略,即根據樣本之間的相對距離均勻地抽取樣本。這就糾正了嵌入空間的幾何形狀所引起的偏差,同時保證了任何數據點都有被采樣的機會。我們提出的采樣導致了較低的梯度方差,從而穩定了訓練,導致質量更好的嵌入,而不用考慮損失函數。
損失函數顯然也很重要。我們提出了一個簡單的基於邊際的損失作為contrastive損失的擴展。它只鼓勵所有的正樣本在一個距離內,而不是盡可能接近。它放寬了損失,使它更有活力。此外,通過使用isotonic回歸,我們的邊際損失關注的是相對順序而不是絕對距離。
我們基於邊際損失和距離加權抽樣的方法在Stanford Online Products、CARS196和CUB200-2011數據集上實現了最先進的圖像檢索和聚類性能。它也優於以前在使用標准公開可用的訓練數據的LFW人臉驗證數據[16]上的最先進的結果。我們的損失函數和采樣策略都是易於實現且訓練有效的。
2. Related Work
利用神經網絡提取某些關系特征的想法可以追溯到90年代。Siamese網絡[4]找到一個嵌入空間,這樣類似的例子有類似的嵌入,反之亦然。這種網絡是端到端訓練的,在所有映射之間共享權值。Siamese網絡首先應用於簽名驗證,后來擴展到人臉驗證和降維[6,11]中。然而,考慮到當時有限的計算能力和它們的非凸特性,這些方法最初並沒有受到太多關注。convex方法更為流行[7,39]。例如,triplet損失[26,36]就是凸優化中出現的最突出的方法之一。
有了足夠的數據和計算能力,這兩個學派的思想結合成一個使用triplet損失的Siamese結構。這使得人臉驗證的表現接近人類[23,25]。受triplet損失啟發,一些人對更多的例子實施了限制。例如,PDDM[15]和直方圖損失[34]使用了quadruplets。除此之外,n-pair損失[28]和Lifted Structure[22]定義了一個batch中所有圖像的約束條件。
這種過多的損失函數很容易讓人想起信息檢索中的排序問題。有一個individual、pair-wise[14],和list-wise[35]方法的組合被用來最大化相關性。值得注意的是isotonic回歸方法,它解開了pairwise比較的糾纏,從而提高了計算效率。有關概述,請參閱[21]。
一些論文探討了其他性質的建模。Structural Clustering[29]優化了集群質量。PDDM[15]提出了一種新的構建局部特征結構的模塊。HDC[41]訓練一個集成來建模不同“hard levels”的例子。相反,我們在這里表明,如果與正確的抽樣策略配對,一個簡單的pairwise損失就夠用了。
樣本選擇技術的研究相對較少。對於contrastive損失,通常從所有可能的對中隨機選擇[3,6,11],有時采用難負樣本挖掘(hard negative mining)[27]方法。對於triplet損失,廣泛采用在FaceNet[25]中提出的semi-hard negative mining方法[22,23]。以加速收斂到相同的全局損失函數為目標,研究了隨機優化[43]的抽樣問題。相反,在嵌入學習中,采樣實際上改變了考慮的整體損失函數。在這篇文章中,我們展示了采樣如何影響實際的深度嵌入學習的表現。
3. Preliminaries
讓f(xi)表示一個數據點的embedding,其中
是一個帶有參數Θ的可微深度網絡。為了訓練的穩定[25],f(xi)經常被歸一化到單位長度。我們的目標是學習一個embedding去讓相似數據點更接近,同時將不相似數據點分開。形式上來說,我們定義兩個數據點之間的距離為
,其中||.||表示歐式范數。對於任意yij=1的正數據點對,其距離應該小;對於yij=0的負數據點對,距離應該大。
contrastive損失通過鼓勵所有正對距離接近0,並保持負對距離在一定的閾值之上的方法直接優化了這個距離:
contrastive損失的一個缺點是我們必須為所有的負樣本對選擇一個恆定的邊界α。這意味着視覺上不同的類與視覺上相似的類一樣嵌入在同樣小的空間中。嵌入空間不允許扭曲。
相比之下,triplet損失只是試圖保持每個樣本的所有的正對都比其任意負對更接近樣本:
這種形式允許嵌入空間被任意扭曲,並且不施加一個恆定的邊界α限制。
從風險最小化的角度來看,我們的目標可能是分別優化所有O(n2) pairs或O(n3) triplets的總損失。這是:
這在計算上是不可行的。此外,一旦網絡收斂,大多數樣本的貢獻很小。
這導致了許多加速收斂的heuristics的出現。對於contrastive損失,hard negative mining通常具有較快的收斂速度。對於triplet損失,則不那么明顯,因為hard negative mining往往會導致模型崩潰,即所有圖像的嵌入都是相同的。因此,FaceNet[25]提出使用一種有點神秘的semi-hard negative mining:給定一個anchor a和一個正例p,通過:
在一個batch中獲得一個負例n。這樣會產生一個違反規則的示例,這個示例是hard的,但還不算too hard。batch構建也很重要。為了獲得更多信息的triplet,FaceNet batch size設置為1800,並確保每個identity在batch[25]中大約有40幅圖像。甚至如何在一個batch中最好地選擇triplets也不是很清楚。Parkhi等人[23]使用在線選擇方法,因此每一對(a, p)只抽樣一個triplet。OpenFace[2]采用了離線triplet選擇,因此一個batch中有三分之一的圖像分別作為anchor、正例和負例。
簡而言之,抽樣很重要。它通過加權樣本隱式地定義了一個相當heuristic的目標函數。這種方法使得復制和擴展insights到不同的數據集、不同的優化框架或不同的體系結構變得困難。在下一節中,我們將分析其中一些技術,並解釋為什么它們能夠提供更好的結果。然后,我們提出了一種新的采樣策略,以超越當前的最先進結果。
4. Distance Weighted Margin-Based Loss
為了理解均勻采樣負例時發生了什么,回想一下我們的嵌入值通常被限制在n維單位球中,n≥128。考慮這些點在球上均勻分布的情況。在這種情況下,pairwise距離的分布如下:
推導可見[1]。圖2展示了測量出現的集中情況。實際上,在高維空間中,q(d)接近。另一方面來說,如果負樣本均勻分散,且對其隨機采樣,我們可能會獲得
的樣本(即遠離
的樣本)。對於小於
的閾值,其將導致無損失產生,因此學習沒有進展。已學習的嵌入遵循非常類似的分布,因此也適用於相同的推理。詳見補充材料。
too hard的負例會引起另一個問題。考慮一個負對t:= (a, n)或一個triplet t:= (a, p, n)。關於負例f(xn)的梯度為:
用於一些函數w(.)和。第一項
決定梯度的方向。當|| han ||比較小時,會出現一個問題,且我們embedding的估計是有噪音的(意思是負樣本距離太小則負樣本的梯度方向容易受噪聲影響,梯度大小接近於0。 所以可能既走不動又可能走錯,距離太大則沒有意義,所以會需要各個距離的,那為了得到各種距離的,最直接就是采樣的概率與出現的概率成反比) 。給定充足地由訓練算法引入的噪音z,方向
被噪音控制。圖3a顯示了用於
梯度方向的協方差矩陣的核范數。我們可以看到,當負例太close/hard時,梯度方差大(方差低才好),信噪比低。與此同時,隨機樣本往往相距太遠,不能產生良好的信號。
Distance weighted sampling. 因此,我們提出了一個新的采樣分布,糾正偏差,同時控制方差。具體來說,我們按距離均勻采樣,即使用權值為q(d)−1的采樣。這給了我們分散,而不是聚集在一個小區域的例子。為了避免噪聲樣本,我們對加權樣本進行了剪切(即距離太小或太大的負例會被去掉)。形式上,給定一個anchor 樣例a, 使用下面公式距離加權采樣 樣本負對(a, n*):
(以圖2舉例,即如果距離Dan遠離(即小於或大於),這樣q(Dan)越小的,q-1(Dan)的值則越大,只要不大於λ,則取其為該負例被選取的概率;因此可知遠離
但又不是很遠,即Dan沒小到接近0或大到特別大的情況下的負例被采樣的概率最高。如果距離Dan接近
,則q(Dan)值比較大,因此q-1(Dan)的值比較小,選取這種負樣本的概率小。
總結就是,因為距離Dan接近的負例多,因此其選取概率低;遠離的負例少,因此選取的概率高。這樣就能均衡選取到相近數量的各種距離Dan的負例)
圖3b對比了不同策略下得到的模擬樣例的梯度變化情況。hard negative mining總是在高方差區域(負例距離太近)提供實例。這導致了噪聲梯度,不能有效地分開兩個例子,因此得到一個爆炸模型。隨機抽樣只產生不造成損失(負例距離過大)的簡單例子。semi-hard negative mining在兩者之間找到一個狹窄的集合。雖然它可能在開始時迅速收斂,但在某些時候,band內沒有留下任何例子,網絡將停止前進。FaceNet報告了一個一致的發現:在某個點之后,損失的減少速度急劇減慢,他們的最終系統花了80天來訓練[25]。距離加權抽樣提供了廣泛的樣本,從而在控制方差的同時穩定地產生信息樣本。在第5節中,我們將看到距離加權抽樣在幾乎所有測試的損失函數中都帶來了性能改進。當然,抽樣只能解決一半的問題,但它使我們能夠分析各種損失函數。
圖4a和圖4b描述了contrastive損失和triplet損失。有兩個關鍵的區別,這通常解釋了為什么triplet損失優於contrastive損失:triplet損失沒有假設一個預定義的閾值來分離相似和不同的圖像。相反,它可以靈活地扭曲空間以容忍outliers,並適應不同類的不同級別的類內差異。第二,triplet損失只要求正例比負例更接近,而contrastive損失則需要努力將所有正例盡可能接近。后者是不必要的。畢竟,對於大多數應用程序,包括圖像檢索、聚類和驗證,維護正確的相對關系就足夠了。
(縱軸是損失,橫軸是樣本對距離。藍色表示的就是當正對距離越小越接近於0時,梯度變化就越小,對模型更新影響小;越大則梯度變化大,對模型更新影響大。綠色即表示負對距離過小時,說明這是個hard negative,所以梯度變化就大,對模型更新影響大;但是如果距離過大,那么梯度接近0,說明這個負對沒什么用)
(舉例說明,可見圖4(c)對應的就是下面的式子,圖的結果的意思是:當正對距離Dap小於Dan-α時,正對的梯度為0,大於則梯度為1;當負對距離Dan小於Dap-α時,梯度為-1,大於時梯度為0)
另一方面,在圖4b中,我們也觀察到,在負例中,triplet損失的損失函數呈凹形。特別要注意的是,對於hard negatives(帶有小Dan),相應的負例的梯度接近於零。在這種情況下,不難看出為什么hard negative mining會導致一個崩潰的模型:因為hard positive pairs給出了很大的吸引梯度,而hard negative pairs給出了很小的排斥梯度,所以所有點最終聚集到同一個點。為了使來自所有距離的例子的損失穩定,一個簡單的補救方法是使用而不是
,即:
圖4c顯示了該損失函數。它對於任何嵌入f(x)的梯度長度都是1。關於使用固定長度梯度的好處的更多討論見例[12,20]。如第5節所示,這種簡單的固定加上距離加權抽樣方法的效果已經超過了傳統的triplet損失。
Margin based loss. 這些觀察激勵我們設計一個損失函數,該函數享有triplet損失的靈活性,其有一個適合於來自所有距離例子的形狀,同時提供了contrastive損失的計算效率。其基本思想可以追溯到ordinal回歸中,即只有分數的相對順序是重要的[17]。也就是說,我們只需要知道兩個集合之間的交叉。Isotonic回歸通過分別估計閾值來利用這一點,然后對相對於閾值的分數進行懲罰。我們使用相同的技巧,現在應用於pairwise距離而不是分數函數。基於邊際的自適應損失定義為:
(所以對於正對來說,損失為α+ Dap - β;而對於負對來說,損失變為α- Dan + β,當Dij=β時,損失相等。圖4(d)的意思是當正對距離Dap小於 β-α 時,此時損失小於0,梯度為0;大於β-α時梯度才為1。當負對距離Dan小於β+α時,梯度為-1;大於則損失小於0,梯度為0)
(所以並不要求所有正例盡可能的小(這里要求大於β-α),因為有時候要允許類內的差異(細粒度識別),而且有時候任務只需要相對關系(圖像檢索))
在這里,β是一個變量,它決定正對和負對之間的邊界,α控制分離的邊界,而yij∈{−1,1}。圖4d顯示了這個新的損失函數。我們可以看到,它與contrastive損失相比,放寬了對正例的限制。它有效地在位移距離Dij−β上施加了很大的邊際損失的。這種損失與支持向量分類器(SVC)[8]非常相似。
為了享受類似triplet損失的靈活性,我們需要一個更靈活的邊界參數β,它依賴於特定類β(class)和特定樣本β(img)項。
尤其是,特定樣本偏差在triplet loss中和閾值是一樣的角色。手動選擇所有的
和
是很不靈活的。相反,我們想要共同學習這些參數。幸運的是,β的梯度可以簡單地計算為:
很明顯,更大的β值更可取,因為它們可以更好地利用嵌入空間。因此,為了正則化β,我們引入了超參數v,這導致了優化問題最終為:
在這里,算法調整了違反左右邊界的點數量之間的差異。這可以通過觀察到它們的梯度需要在一個最佳β中抵消來看出。請注意,這里使用的v非常類似於v-svm[24]中的v技巧。
Relationship to isotonic regression. 優化基於邊際的損失可以看作是解決一個距離排序問題。從技術上講,它與信息檢索[21,44]中的排序學習問題相似。為了查看第一個注意事項-最佳β,經驗風險可以寫為:
這是一個定義在絕對誤差上的isotonic回歸。我們看到,基於邊際損失是保持相對orders的“minimum-effort”更新的數量。它側重於相對關系,即側重於正對距離和負對距離的分離。這與傳統的損失函數(如contrastive損失)相反,在contrastive損失函數中,損失是相對於預定義的閾值定義的。
5. Experiments
對該方法在圖像檢索、聚類和驗證等方面進行了評價。在圖像檢索和聚類方面,我們使用Stanford Online Products[22]、CARS196[19]和CUB200-2011[37]數據集,遵循Song等[22]的實驗設置。The Stanford Online Product 數據集包含22,634個類別的120,053幅圖像。前11,318個類別用於訓練,其余的用於測試。CARS196數據集包含196個模型的16,185張汽車圖像。我們使用前98個模型進行訓練,其余的用於測試。CUB200-2011數據集包含了200種鳥類的11,788幅圖像。前100個物種用於訓練,其余的用於測試。
我們基於標准Recall@k度量來評估圖像檢索質量,像Song et al.[22]一樣。給定ground-truth聚類,我們使用NMI分數
去評估聚類對齊
的質量。其中I(.,.)和H(.)分別表示互信息和熵。我是用K-means算法進行聚類。
為了驗證,我們在最大的公開人臉數據集CASIA-WebFace[40]上訓練我們的模型,並在標准的LFW[16]數據集上評估。VGG人臉數據集[23]更大,但是它的許多鏈接已經過期。CASIA-WebFace數據集包含494,414張10,575個人的圖像。LFW數據集由13,233張5,749個人的圖像組成。它的驗證基准包含6000個驗證對,分成10個子集。我們根據剩下的9個split為一個split選擇驗證閾值。
除特別說明外,我們在所有實驗中使用的嵌入尺寸為128,輸入圖像尺寸為224×224。所有的模型都使用Adam[18]進行訓練,人臉驗證的批量大小設置為200,Stanford Online Products設置為80,其他實驗設置為128。網絡體系結構遵循ResNet-50(預激活)[13]。為了加快訓練速度,我們在人臉驗證實驗中使用了一個簡化版的ResNet-50。具體來說,我們在5個階段中分別使用了64、96、192、384、768個過濾器,而不是最初提出的64、256、512、1024、2048個過濾器。我們沒有觀察到由於更改導致的任何明顯的性能下降。采用水平鏡像和256×256的隨機裁剪進行數據增強。在測試期間,我們使用單一中心裁剪防范。人臉圖像由MTCNN[42]進行對齊。當對齊失敗時,我們使用中心裁剪。遵循FaceNet[25],我們使用了α= 0.2,對於基於邊際的損失,我們初始化了β(0)=1.2和β(class)=β(img)=0。
請注意,以前的一些論文使用了提供的邊框,而其他的則沒有使用。為了公平地與以前的方法進行比較,我們在原始圖像和邊界框裁剪的圖像上評估我們的方法。對於CARS196數據集,我們將裁剪后的圖像縮放到256×256。對於CUB200,我們縮放和填充圖像,使其較長的邊為256像素,保持長寬比固定。
我們的batch構建遵循FaceNet[25]。我們一個batch中每個類使用m=5個正樣本。一個batch中的所有正對都被取樣。對於正對中的每一個例子,我們采樣一個負對。這確保了正對和負對的數量是平衡的,並且每個示例都屬於相同數量的正對和相同數量的負對。
5.1. Ablation study
我們首先了解損失函數、自適應邊際和特定的功能選擇的影響。我們專注於Stanford Online Products,因為它是三個圖像檢索數據集中最大的。注意,圖像檢索傾向於triplet損失而不是contrastive損失,因為只有相對關系matters。這里所有的模型都是從頭開始訓練的。由於不同的方法以不同的速度收斂,所有方法都訓練了100個epoch,並在它們的最佳epoch而不是在訓練結束時報告性能。
我們在我們的距離加權采樣方法中對隨機抽樣和semi-hard negative mining兩種采樣方法進行比較。對於semi-hard negative mining,對pairwise損失函數的距離下界沒有自然選擇。因此在本實驗中,我們使用0.5的下界來模擬triplet損失的正距離。我們考慮contrastive損失、triplet損失和我們基於margin的損失。所謂隨機抽樣,是指從所有正對和負對中均勻抽樣。由於這樣的定義不適用於triplet損失,我們只測試contrastive損失和我們基於margin的損失。
結果如表1所示。我們看到,給定相同的損失函數,不同的抽樣分布會導致非常不同的性能。特別地,雖然在隨機抽樣中,contrastive損失產生的結果比triplet損失的差得多,但當使用類似於triplet損失的抽樣程序時,它的性能顯著提高。這一證據駁斥了對contrastive損失與triplet損失的常見誤解:triplet損失的強度不僅來自損失函數本身,更重要的是來自所伴隨的抽樣方法。此外,距離加權采樣方法始終如一地為幾乎所有的損失函數提供了性能提升。唯一的例外是contrastive損失。我們發現它對超參數非常敏感。雖然我們為隨機抽樣和semi-hard negative mining找到了良好的超參數,但我們還不能為距離加權抽樣找到性能良好的超參數。另一方面,基於邊際的損失自動學習一個合適的β並很好地進行訓練。值得注意的是,不論抽樣策略如何,基於邊際的損失在很大程度上優於其他損失函數。這些觀察結果適用於多個批處理大小,如表2所示。我們還嘗試使用ILSVRC 2012-CLS[9]數據集對我們的模型進行預訓練,這在之前的工作中很常見[3,22]。預訓練可以提高10%的recall。在下面的部分中,我們將重點討論預先訓練好的模型,以便進行公平比較。
接下來,我們對這些方法進行定性評價。圖5顯示了對隨機選擇的查詢圖像的檢索結果。我們可以看到,triplet損失一般能提供合理的結果,但也會出現一些錯誤。另一方面,我們的方法給出了更准確的結果。
通過學習一個靈活邊界β去評估所獲得的收益,我們比較了使用一個固定β建模和使用可學習的βs建模的效果。結果如表3所示。我們可以看到,使用更靈活的特定類的β(class)確實比各種固定β(0)值更有優勢。我們還使用特定示例的β(img)進行了測試,但實驗沒有得出結論。我們推測,學習特定示例的β(img)可能引入了太多參數,導致過擬合。
Convergence speed. 進一步分析了采樣對收斂速度的影響。我們將我們使用距離加權抽樣的基於邊際的損失方法與兩種最常用的深度嵌入方法相比較:即semi-hard抽樣的triplet損失和隨機抽樣的contrastive損失。學習曲線如圖6所示。我們看到,semi-hard negative mining訓練的triplet損失收斂速度較慢,因為它忽略了太多的例子。隨機抽樣的contrastive損失損失收斂得更慢。距離加權抽樣使用信息更豐富、更穩定的樣本,收斂速度更快、更准確。
Time complexity of sampling 采樣的計算代價可以忽略不計。在Tesla P100 GPU上,前向和后向傳播每個batch(大小120)大約需要0.55秒。semi-hard采樣只需要0.00031秒,距離加權采樣只需要0.0043秒,即使在我們的單線程CPU實現中也是如此。兩種策略都取O (nm(n−m)),其中n為批大小,m為批中每類圖像的數量。
5.2. Quantitative Results
我們現在將我們的方法與其他最先進的方法進行比較。圖像檢索和聚類結果見表4、表5、表6。我們可以看到,我們的模型在所有三個數據集中取得了最好的性能。特別是,基於邊際的損失優於triplet損失的擴展,如LiftedStruct [22], StructClustering [29], N-pair[28],和PDDM[15]。它的性能也優於histogram損失[34],后者需要計算相似直方圖。還要注意的是,我們的模型只對每個圖像使用了一個128維的嵌入。這比HDC[41]更簡潔,HDC[41]為每張圖像使用3個嵌入向量。
表7給出了人臉驗證的結果。在CASIA-WebFace上訓練的所有模型中,我們的模型達到了最好的精度。還需要注意的是,在這里我們的方法優於使用廣泛訓練程序的模型。MFM[38]使用softmax分類損失。CASIA[40]使用softmax損失和contrastive損失的組合。N-pair[28]使用一個更昂貴的損失函數,該函數在批處理中對所有對進行定義。我們還列出了一些其他的先進的結果,他們不能純粹地作為參考。DeepID2[30]和DeepID3[31]根據人臉landmarks位置在25個人臉區域上使用25個網絡。當只使用一個網絡進行訓練時,它們的性能會顯著下降。其他的模型如FaceNet[25]和Deep-Face[33]都是在巨大的私有數據集上訓練的。
總體而言,我們的模型在所有比較方法中、在所有數據集上都取得了最好的結果。值得注意的是,我們的方法使用了最簡單的損失函數——即contrastive損失的一個簡單變體。
6. Conclusion
我們證明了在深度嵌入學習中,抽樣比損失函數更重要。這並不奇怪,因為隱式定義的損失函數(相當明顯)是一個加權樣本對象。
我們的新距離加權抽樣方法提高了多個損失函數的性能。此外,我們分析並提供了一個簡單的基於邊際的損失,它放寬了傳統contrastive損失的不必要約束,並享有triplet損失的靈活性。我們表明,距離加權抽樣和基於邊際損失顯着優於所有其他損失函數。
代碼實現:

class DistanceWeightedSampling(HybridBlock): r"""Distance weighted sampling. See "sampling matters in deep embedding learning" paper for details. Parameters ---------- batch_k : int Number of images per class. Inputs: - **data**: input tensor with shape (batch_size, embed_dim). Here we assume the consecutive batch_k examples are of the same class. For example, if batch_k = 5, the first 5 examples belong to the same class, 6th-10th examples belong to another class, etc. Outputs: - a_indices: indices of anchors. - x[a_indices]: sampled anchor embeddings. - x[p_indices]: sampled positive embeddings. - x[n_indices]: sampled negative embeddings. - x: embeddings of the input batch. """ def __init__(self, batch_k, cutoff=0.5, nonzero_loss_cutoff=1.4, **kwargs): self.batch_k = batch_k self.cutoff = cutoff # We sample only from negatives that induce a non-zero loss. # These are negatives with a distance < nonzero_loss_cutoff. # With a margin-based loss, nonzero_loss_cutoff == margin + beta. self.nonzero_loss_cutoff = nonzero_loss_cutoff super(DistanceWeightedSampling, self).__init__(**kwargs) def hybrid_forward(self, F, x): k = self.batch_k n, d = x.shape distance = get_distance(F, x) # Cut off to avoid high variance. distance = F.maximum(distance, self.cutoff) # Subtract max(log(distance)) for stability. log_weights = ((2.0 - float(d)) * F.log(distance) - (float(d - 3) / 2) * F.log(1.0 - 0.25 * (distance ** 2.0))) weights = F.exp(log_weights - F.max(log_weights)) # Sample only negative examples by setting weights of # the same-class examples to 0. mask = np.ones(weights.shape) for i in range(0, n, k): mask[i:i+k, i:i+k] = 0 weights = weights * F.array(mask) * (distance < self.nonzero_loss_cutoff) weights = weights / F.sum(weights, axis=1, keepdims=True) a_indices = [] p_indices = [] n_indices = [] np_weights = weights.asnumpy() #size is (batch_size * batch_size),即np_weights[2,5]表示選擇第6張圖作為第3張圖負例的權重 for i in range(n): block_idx = i // k try: n_indices += np.random.choice(n, k-1, p=np_weights[i]).tolist() #即[0:n]這個圖以np_weights[i]一一對應的權重采樣,取k-1個做負例 except: n_indices += np.random.choice(n, k-1).tolist() for j in range(block_idx * k, (block_idx + 1) * k): if j != i: #因為每K個batch_k是同一個類,所以當a圖為i是,取非i的其他k-1張圖作為正例 a_indices.append(i) p_indices.append(j) return a_indices, x[a_indices], x[p_indices], x[n_indices], x def __repr__(self): s = '{name}({batch_k})' return s.format(name=self.__class__.__name__, **self.__dict__)