論文解讀(SimCLR)《A Simple Framework for Contrastive Learning of Visual Representations》


論文信息

論文標題:Deep Graph Clustering via Dual Correlation Reduction
論文作者:Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey E. Hinton
論文來源:2020, ICML
論文地址:download
論文代碼:download

1 介紹

  本文主要介紹 SimCLR框架。

  定義:

    SimCLR:一個簡單的視覺表示對比學習框架,不僅比以前的工作更出色,而且也更簡單,既不需要專門的架構,也不需要儲存庫。

  性能:

  在 $ImageNet$ 上大大優於以前的自監督和半監督學習方法。在 $SimCLR$ 學習的自監督表示上訓練的 線性分類器 實現了 $76.5%$ 的 $top-1$ 准確率,相對於之前的最新技術水平提高了 $7 \%$,與監督 $ResNet-50$(下圖中的 $gray cross$) 的性能相匹配。當僅對 $1 \% $ 的標簽進行微調時,我們實現了 $85.8\%$ 的 $top-5$ 准確率,在標簽數量減少 $100$ 倍的情況下優於 $AlexNet$。

       

  圖解:
  1、隨着模型的增大($Parameters$的增加),$SimCLR$ 的性能也在不斷的增加,體現了 $SimCLR$ 貢獻3:“對比學習的好處在於使用更大的批量和更多的訓練步驟”。

  2、$SimCLR$ 性能在 $ImageNet$ 上的性能遠高於其他方法除$Sup ResNet50$。

 SimCLR框架優勢:

  • 多個數據增強組合對於定義產生有效表示的對比預測任務至關重要。 此外,無監督的對比學習受益於比監督學習更強的數據增強
  • 在表示和對比損失之間引入可學習的非線性變換,大大提高了學習表示的質量。
  • 具有對比交叉熵損失的表示學習受益於歸一化嵌入適當的溫度參數 $\tau$
  • 對比學習與監督學習相比,受益於更大的批量($Batch$)和更長的訓練時間。 與監督學習一樣,對比學習也受益於更深更廣的網絡

3 方法

3.1 對比學習框架

  對比學習是一種為機器學習模型描述相似和不同事物的任務的方法。它試圖教機器區分相似和不同的事物。

      

      

  $SimCLR$ 最終目的是最大化同一數據示例的不同增強視圖之間的一致性來學習表示,即 $max \  similar(\mathbf{v_1} ,\mathbf{v_2} )$

      

   SimCLR 框架包括以下四個主要組件:

  1、隨機數據增強模塊。隨機轉換任何給定的數據示例,生成同一數據示例的兩個相關視圖表示並定義  ${\widetilde{x}_i }$和 ${\widetilde{x}_j}$ 是正對。本文組合應用三種增強:隨機裁剪然后調整回原始大小(random cropping and resize back)隨機顏色失真(color distortions) 和 隨機高斯模糊(random Gaussian blur)

  2、基礎編碼器(base encoder) $f(\cdot )$。用於從生成的視圖中提取表示向量,允許選擇各種網絡架構。本文選擇 $ResNet$ 獲得$h_i=f(\widetilde{x}_i )=ResNet(\widetilde{x}_i)$,生成的表示$h_i \in R^d$是平均池化層$(average pooling layer)$后的輸出。

  3、投影頭(projection head) $g(·) $將表示映射到應用對比損失的空間。 本文使用一個帶有一個隱藏層的 $MLP$ 來獲得 $z_i =g(h_i)=w^{(2)}\sigma (w^{(1)}h_i)$ 其中 $\sigma$是一個 $ReLU$ 非線性函數。此外,發現在 $z_i $而非 $h_i $ 上定義對比損失是有益的。

   4、對比損失函數(contrastive loss function)。 給定 $batch$ 中一組生成的視圖 $\{\widetilde{x}_k \}$,其中包括一對正例 ${\widetilde{x}_i }$ 和 ${\widetilde{x}_j}$ ,對比預測任務旨在對給定 ${\widetilde{x}_i}$ 識別 $\{{\widetilde{x}_j} \}_{k\ne i } $ 中的${\widetilde{x}_j}$ 。

  隨機抽取 $N$ 個樣本的小批量樣本,並在從小批量樣本上生成增強視圖,從而產生 $2N$ 個數據點。 本文無明確地指定負例,而是給定一個正對$(positive pair)$,將小批量中的其他 $2N − 2 $個增強示例視為負示例。本文定義相似度為余弦相似度$sim(u,v)=\frac{u^Tv}{||u||\ ||v||} $。則一對正對 $(i,j)$的損失函數定義為:

    $l_{i,j}=-log( \frac{exp(sim(z_i,z_j)/\tau )}{ {\textstyle \sum_{k=1}^{2N}}1_{[k\ne i]} \ exp(sim(z_i,z_j)/\tau)  } )$

  其中 $1_{[k\ne i]} \in \{0,1\}$ 是指示函數,當 $k\ne i$ 為 $1$ 。$\tau$是溫度參數。最終損失是在小批量中計算所有正對 $(i,j)$ 和 $(j,i)$ 的。為方便起見,將其稱為 $NT-Xent$(歸一化溫度標度交叉熵損失)。

  算法流程

      

  圖解算法流程:

  Step1:隨機數據增強模塊
  首先,原始圖像數據集生成若干大小為 $N$ 的 $batch$。這里假設取一批大小為 $N = 2$ 的 $batch$。本文使用 $8192$ 的大 $batch$。
      
  定義隨機數據增強函數 $T$ ,本文應用 $random (crop and resize back + color distortions + Gaussian blur)$。

      

  對於 $batch$ 中的每一幅圖像,使用隨機數據增強函數 $T$ 得到一對$view$。對 $batch$ 為 $2$ 的情況,得到 $2N = 4$ 張圖像。

      

  Step2基礎編碼器(base encoder) $f(\cdot )$
  對增強過的圖像通過一個編碼器來獲得圖像表示。所使用的編碼器是通用的,可與其他架構替換。下面的兩個編碼器共享權值,得到表示$vector$ $h_i$和$h_j$。

      

  在本文中,作者使用 $ResNet-50$ 架構作為編碼器。輸出是一個 $2048$ 維的向量 $h$。

      

  Step投影頭(projection head) $g(·) $將表示映射到應用對比損失的空間。
  本文使用一個帶有一個隱藏層的 $MLP$ 來獲得 $z_i =g(h_i)=w^{(2)}\sigma (w^{(1)}h_i)$ 其中 $\sigma$是一個 $ReLU$ 非線性函數。

      

  Step4使用對比損失函數進行模型調優。
  對於 $batch$ 中的每個增強過的圖像通過基礎編碼器 $f(\cdot  )$,得到嵌入向量 $z$。

      

   使用嵌入向量$z_i$,計算損失的步驟如下: 

  a. 計算余弦相似性

  用余弦相似度計算圖像的兩個增強的圖像之間的相似度。對於兩個增強的圖像 $x_i$ 和 $x_j$,在其投影表示 $z_i$ 和 $z_j$ 上計算余弦相似度。

      

    $s_{i,j} = \frac{ \color{#ff7070}{z_{i}^{T}z_{j}} }{( ||\color{#ff7070}{z_{i}}|| ||\color{#ff7070}{z_{j}}||)}$
  其中

  • $\lVert z_{i} \rVert$是矢量的模。

   使用上述公式計算 $batch$ 中每個增強圖像之間的兩兩余弦相似度。如圖所示,在理想情況下,增強后的貓的圖像之間的相似度會很高,而貓和大象圖像之間的相似度會較低。

       

  b. 損失的計算
  $SimCLR$使用了一種對比損失,稱為“$NT-Xent$損失”(歸一化溫度-尺度交叉熵損失)。工作步驟如下:

  首先,將 $batch$ 的增強對逐個取出。

      

  接下來,我們使用和 $softmax$ 函數原理相似的函數來得到這兩個圖像相似的概率。

      

   這種 $softmax$ 計算等效於獲得第二張增強貓圖像與該對中的第一張貓圖像最相似的概率。批次中的所有剩余圖像都被采樣為不同的圖像(負對)。 因此,我們不需要像 $InstDisc$、$MoCo$ 或 $PIRL$ 等以前的方法那樣需要專門的架構、存儲庫或隊列。

      

  然后,取上述計算的負對數來計算這一對圖像的損失。

    $l_{i,j}=-log( \frac{exp(sim(z_i,z_j)/\tau )}{ {\textstyle \sum_{k=1}^{2N}}1_{[k\ne i]} \ exp(sim(z_i,z_j)/\tau)  } )$

      

   圖像位置互換,再次計算同一對圖像的損失。

      

  計算 $Batch size N=2$ 的所有配對的損失並取平均值。
    $L = \frac{1}{ 2N } \sum \limits _{k=1}^{N} [l(2k-1, 2k) + l(2k, 2k-1)]$

      

  最后,更新網絡 $f(\cdot )$ 和 $g$ 以及最小化 $L$。

3.2.大批量訓練

  本文將訓練批次大小 $N$ 從 $256$ 改變到 $8192$。$8192$ 的批次大小提供了來自兩個增強視圖的 $2$ 個正示例 $16382$ 個負示例。大批量訓練可能不穩定,為了穩定訓練,我們對所有批次大小使用 $LARS$優化器。我們使用 $Cloud \ TPU$ 訓練我們的模型,根據批量大小使用 $32$ 到 $128$ 個內核,$2 $ 全局 $BN$。

4 對比表示學習的數據增強

  數據增強定義了預測任務。雖然數據增強已廣泛用於有監督和無監督的表示學習,但它並未被視為定義對比預測任務的系統方法,許多現有方法通過改變架構來定義對比預測任務。Bachman 等人通過約束網絡架構中的感受野來實現全局到局部的視圖預測,而 Oord 等人則通過約束網絡架構中的感受野來實現全局到局部的視圖預測。赫納夫等人通過固定的圖像分割過程和上下文聚合網絡實現相鄰視圖預測。我們表明,可以通過對目標圖像執行簡單的隨機裁剪(調整大小)來避免這種復雜性,這創建了一系列包含上述兩個的預測任務。這種簡單的設計選擇方便地將預測任務與其他組件(如神經網絡架構)分離,可以通過擴展增強系列並隨機組合它們來定義更廣泛的對比預測任務。

4.1 數據增強操作的組合對於學習良好的表示至關重要

  為了系統地研究數據增強的影響,本文考慮了幾種常見的增強。 一種類型的增強涉及數據的空間/幾何變換,如裁剪和調整大小、旋轉和剪切。 另一種類型的增強涉及外觀變換,例如顏色失真(包括顏色下降、亮度、對比度、飽和度、色調)、高斯模糊和 Sobel 過濾。 下圖可視化了我們在這項工作中研究的增強。

      

  為了解單個數據增強的影響和增強組合的重要性,本文研究了我們的框架在單獨或成對應用增強時的性能。 由於 ImageNet 圖像大小不同,本文總是應用裁剪調整圖像大小,這使得在沒有裁剪的情況下很難研究其他增強。 為了消除這種混淆,我們考慮了這種消融的非對稱數據轉換設置。 具體來說,我們總是首先隨機裁剪圖像並將它們調整為相同的分辨率,然后只將目標轉換應用於圖 2 中框架的一個分支,而將另一個分支作為身份(即 $t (x_i ) = x_i )$。

  如下圖顯示了數據增強操作單獨組合變換下的線性評估結果(linear evaluation result)。觀察到,即使模型幾乎可以完美地識別對比任務中的正對,也沒有單一的轉換足以學習好的表示。對組合進行增強時,對比預測任務變得更加困難,但表示質量顯着提高。

      

  一種增強組合脫穎而出:隨機裁剪隨機顏色失真 (random cropping and random color distortion)。推測僅使用隨機裁剪作為數據增強時的一個嚴重問題是圖像中的大多數 $patch$ 共享相似的顏色分布。下圖顯示單獨的顏色直方圖就足以區分圖像。神經網絡可以利用這個捷徑來解決預測任務。因此,為了學習可概括的特征,將裁剪與顏色失真組合起來至關重要。

      

  PS:顏色直方圖

4.2 對比學習需要比監督學習更強的數據增強

  為了進一步證明顏色增強的重要性,本文調整了顏色增強的強度,如下表所示。更強的顏色增強顯着改善了學習的無監督模型的線性評估。 在這種情況下,$AutoAugment$  是一種使用監督學習發現的復雜增強策略,其效果並不比簡單裁剪+(更強)顏色失真( simple cropping+ (stronger) color distortion) 更好。 當使用相同的增強集訓練監督模型時,觀察到更強的顏色增強不會改善甚至損害它們的性能。 因此,我們的實驗表明,與監督學習相比,無監督的對比學習受益於更強的(顏色)數據增強 

      

  PS:SimCLR中的無監督ResNet-50與監督ResNet-50。

5 編碼器和投影頭的架構

5.1 無監督的對比學習從更大的模型中獲益更多

  如圖所示,增加深度和寬度都可以提高性能。 雖然類似的發現適用於監督學習,但我們發現監督模型和在無監督模型上訓練的線性分類器之間的差距隨着模型大小的增加而縮小,表明無監督學習從更大的模型中受益比其監督對應物更多。

      

  PS:在線性分類器中比較監督學習和無監督學習。

5.2 非線性投影頭提高了之前圖層的表示質量

  研究投影頭的重要性,即 $g(h)$。 下圖顯示了使用三種不同的頭部架構的線性評估結果:(1)身份映射(identity mapping);(2)線性投影(Linear projection);(3)非線性投影(Non-linear projection)。觀察到非線性投影比線性投影(+3%)好,比沒有投影(>10%)好得多。 當使用投影頭時,無論輸出尺寸如何,都會觀察到類似的結果。 此外,即使使用非線性投影,投影頭之前的層 $h$ 仍然比之后的層 $z = g(h)$  好得多(> 10%),表明投影頭之前的隱藏層是 比之后的層更好的表示。

      

  PS:橫坐標表示 $z$ 的維度。

  本文推測在非線性投影之前使用表示的重要性是由於對比損失引起的信息損失。 特別是,$z=g(h)$ 被訓練為對數據變換保持不變。 因此,$g$ 可能刪除對下游任務有用的信息,例如對象的顏色或方向。 通過利用非線性變換 $g(·)$,可以形成和保持更多的信息。 為了驗證這個假設,使用 $h$  和 $g(h)$  來學習預測在預訓練期間應用的變換。 這里我們設置 $g(h)=W^{(2)}\sigma (W^{(1)}H)$,具有相同的輸入和輸出維度 (即 2048)

      

  PS:在不同的表示上訓練額外的 MLP來預測轉換的准確性。

  本文驗證了一個猜想:$h$ 中含有更多的信息,遠多於 $g(h)$ 。用 $h$ 和$g(h)$ 來衡量一個圖像做了什么工作,圖中通過分類任務使用表示 $h$ 或 $g(h)$比較 ,得出兩者的准確性。准確性越高說明含有原始數據信息越多,通過對比可以發現使用表示 $h$ 的准確性遠高於使用表示 $g(h)$。說明較表示 $g(h)$, 表示$h$含有更多的信息。

6 損失函數和批量大小

6.1 具有可調溫度的歸一化交叉熵損失比替代方案效果更好

  本文將 $NT-Xent$ 損失與其他常用的對比損失函數進行比較,例如邏輯損失邊際損失

  為了使比較公平,對所有損失函數使用相同的 $l_2$ 標准化$(l_2 \  normalization)$方法,並調整超參數,並報告它們的最佳結果。下表顯示,雖然 $(semi-hard negative mining)$ 有幫助,但最佳結果是仍然比我們默認的 $NT-Xent$ 損失更糟糕。

      

  PS:使用不同損失函數訓練的模型的線性評估(top-1)。 “sh”表示使用半硬負挖掘。

  接下來測試 $l_2$ 標准化$(l_2 \ \  normalization)$(即余弦相似度與點積)和溫度 $\tau $ 在我們默認的 NT-Xent 損失中的重要性下表顯示,如果沒有標准化和適當的溫度縮放,性能會明顯變差。 如果沒有 $l_2$ 標准化,對比任務的准確性更高,但在線性評估下得到的表示更差。

      

6.2 對比學習受益於更大的批量和更長的訓練時間

   下圖顯示針對不同時期 (epoch) 數訓練模型時批量大小的影響發現當訓練時期 (epoch)的數量很少(例如 100 個時期 (epoch))時,較大的批次大小 (batch sizes) 比較小的批次具有顯着的優勢隨着更多的訓練步驟/時期,不同批次大小之間的差距會減少或消失,前提是批次是隨機重新采樣的。與監督學習相反,在對比學習中,更大的批次大小提供更多的負樣本,促進收斂(即,對於給定的准確度,采用更少的時期和步驟)。 訓練時間越長,也會提供更多的負面例子,從而改善結果。 

      

  PS:線性評估模型(ResNet50)在不同batch size 和epoch下的准確性。

7 與最先進技術的比較

  本文在 $3$ 個不同的隱藏層寬度(寬度乘數為 $1×$、$2×$ 和 $4×$)中使用 $ResNet-50$。 為了更好的收斂,模型訓練了 $1000$ 個 $epoch$。

  下表將我們的結果與之前的方法進行了比較在線性評估比較。 與以前需要專門設計的架構的方法相比,我們能夠使用標准網絡獲得更好的結果。 使用我們的 $ResNet-50 (4x) $獲得的最佳結果可以匹配監督預訓練的 $ResNet-50$(前文所提)。

       

  PS:線性分類任務中的比較結果

  半監督學習在沒有正則化的情況下對標記數據的整個基礎網絡進行微調。 下表顯示了我們的結果與最近的方法的比較。 同樣,我們的方法顯着改進了 1% 和 10% 的標簽。 

      

結論

  因此,SimCLR 提供了一個強大的框架,可以在這個方向上進行進一步的研究,並改善計算機視覺的自監督學習狀態。

參考文獻

1.圖解SimCLR框架,用對比學習得到一個好的視覺預訓練模型

2.The Illustrated SimCLR Framework


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM