contrastive loss


Kaiming He's new paper, MoCo, came out yesterday, and the contrastive loss familiar from reID is gradually making its influence felt upstream. I have never put together a proper summary of this area, so I am using this paper as an excuse to organize my notes; consider this a placeholder post that I will keep filling in.


Distance metric learning aims to learn an embedding representation of the data that keeps similar data points close and dissimilar data points far apart in the embedding space [1].

1. Improved Deep Metric Learning with Multi-class N-pair Loss Objective [NIPS 2016] [pdf] [code-pytorch]

  • Definitions:

    • Sample data $x \in \mathcal{X}$ with labels $y \in \{1,2,...,L\}$; $x_+,x_-$ denote a positive and a negative example for the input (same class / different class); $f(\cdot\,;\theta): \mathcal{X} \rightarrow \mathbb{R}^K$ is the feature embedding and $f(x)$ its embedding vector; $m$ is the margin
    • Contrastive loss takes pairs of examples as input and trains the network to predict whether the two inputs come from the same class or not: $$\mathcal{L}^{m}_{cont}(x_i,x_j;f)=\mathbb{1}\{y_i=y_j\}||f_i-f_j||_2^2+\mathbb{1}\{y_i \ne y_j\}\max(0,m-||f_i-f_j||_2)^2$$
    • Triplet loss shares a similar spirit with contrastive loss, but is built from triplets, each consisting of a query, a positive example (relative to the query), and a negative example: $$\mathcal{L}^{m}_{tri}(x,x_+,x_-;f)=\max(0, ||f-f_+||_2^2-||f-f_-||_2^2+m)$$ (a minimal PyTorch sketch of both losses follows this list)
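As referenced above, here is a minimal PyTorch sketch of these two pairwise losses, written directly from the formulas; the function names and the batched interface are my own for illustration, not taken from any of the linked repositories:

```python
import torch.nn.functional as F

def contrastive_loss(f_i, f_j, same_class, m=1.0):
    """Pairwise contrastive loss: pull same-class pairs together and push
    different-class pairs at least margin m apart.
    f_i, f_j: (B, K) embeddings; same_class: (B,) bool tensor."""
    d = F.pairwise_distance(f_i, f_j)                    # ||f_i - f_j||_2
    pos = same_class.float() * d.pow(2)                  # same class: squared distance
    neg = (~same_class).float() * F.relu(m - d).pow(2)   # different class: squared hinge
    return (pos + neg).mean()

def triplet_loss(f, f_pos, f_neg, m=1.0):
    """Triplet loss: anchor-positive distance should be smaller than
    anchor-negative distance by at least margin m."""
    d_pos = (f - f_pos).pow(2).sum(dim=1)
    d_neg = (f - f_neg).pow(2).sum(dim=1)
    return F.relu(d_pos - d_neg + m).mean()
```

For the triplet case, PyTorch also ships `torch.nn.TripletMarginLoss`, which by default uses the unsquared Euclidean distance rather than the squared one above.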

  • Method

    • To address the slow convergence and poor local optima (which hard negative mining can mitigate) of earlier methods that train with only a single negative example at a time, the paper proposes the multi-class N-pair loss and negative class mining
    • multi-class N-pair loss
      • Each batch selects $N$ classes and one pair of examples per class, i.e. $\{(x_1, x_1^+),...,(x_N, x_N^+)\}$, and builds N tuplets $\{S_i\}_{i=1}^N$, where $S_i=\{x_i, x_1^+, x_2^+, ..., x_N^+\}$ contains one positive pair and N-1 negative pairs
      • Loss function: $$\mathcal{L}_{N-pair-mc}(\{(x_i,x_i^+)\}_{i=1}^N;f)= \frac{1}{N} \sum_{i=1}^N \log(1+\sum_{j \ne i} \exp(f_i^T f_j^+ - f_i^T f_i^+))$$
      • Note that $$\log(1+\sum_{i=1}^{L-1} \exp(f^T f_i - f^T f^+))=-\log \frac{\exp(f^T f^+)}{\exp(f^T f^+)+ \sum_{i=1}^{L-1}\exp(f^T f_i )}$$ which looks just like softmax! (a PyTorch sketch of this loss follows the list below)
    • negative class mining
      1. Evaluate Embedding Vectors: choose randomly a large number of output classes C; for each class, randomly pass a few (one or two) examples to extract their embedding vectors.
      2. Select Negative Classes: select one class randomly from the C classes of step 1. Next, greedily add the new class that violates the triplet constraint the most w.r.t. the selected classes, until we reach N classes. When a tie appears, we randomly pick one of the tied classes. I don't fully understand this step; roughly, it seems to keep picking the hardest remaining class each time until N classes are selected? (a rough sketch of this reading appears after the list)
      3. Finalize N-pair: draw two examples from each selected class from step 2.
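A minimal PyTorch sketch of the N-pair-mc loss above, assuming the batch is arranged as N anchors and their N positives; the function name and interface are my own, not the paper's reference code:

```python
import torch
import torch.nn.functional as F

def n_pair_mc_loss(anchors, positives):
    """Multi-class N-pair loss.
    anchors:   (N, K) embeddings f_i
    positives: (N, K) embeddings f_i^+; positives[i] is the positive for
               anchors[i] and serves as a negative for every other anchor."""
    logits = anchors @ positives.t()   # (N, N), entry (i, j) = f_i^T f_j^+
    # log(1 + sum_{j != i} exp(f_i^T f_j^+ - f_i^T f_i^+)) is exactly
    # softmax cross-entropy with the diagonal entry as the target class,
    # which is the softmax-like identity noted above.
    targets = torch.arange(anchors.size(0), device=anchors.device)
    return F.cross_entropy(logits, targets)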
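```

And a rough sketch of the negative class mining loop (steps 1-2). The greedy criterion here, add the candidate class most similar to the classes already chosen, is only my reading of "violates the triplet constraint the most"; `class_embs` and the function name are hypothetical:

```python
import random

def mine_negative_classes(class_embs, n_classes):
    """Greedy negative class mining sketch.
    class_embs: dict class_id -> (K,) embedding averaged over a few random
                examples of that class (step 1).
    Returns n_classes class ids, each new one chosen as the hardest
    (most similar) w.r.t. the classes already selected (step 2)."""
    remaining = set(class_embs)
    seed = random.choice(list(remaining))          # start from a random class
    selected = [seed]
    remaining.remove(seed)
    while len(selected) < n_classes and remaining:
        # pick the class whose embedding is closest to any already-selected class
        best = max(
            remaining,
            key=lambda c: max(float(class_embs[c] @ class_embs[s]) for s in selected),
        )
        selected.append(best)
        remaining.remove(best)
    return selected
```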

2. Unsupervised Feature Learning via Non-Parametric Instance Discrimination [pdf] [code-pytorch]

 

3. Momentum Contrast for Unsupervised Visual Representation Learning

 

