Contrastive Predictive Coding(CPC)


Representation Learning with Contrastive Predictive Coding(CPC)

論文鏈接

代碼鏈接

為什么叫Contrastive Predictive Coding
也就是說當前輸入\(x_i\)和經過encoder學得的latent vector \(y_i\)配對,送入discriminator告訴ta這是正樣本

Motivation

下圖中\(g_{enc}\)是nonlinear encoder,\(z_t\)是編碼后的特征latent vector

\(g_{ar}\) is an autoregressive model that summarizes all \(z \le t\) in the latent space and produces a context latent representation\(c_t=g_{ar}(z \le t)\)

也就是說將\(z_t\)以及之前所有時刻的相關信息輸入到一個自回歸模型中,生成當前時刻的上下文表示\(c_t\)

Predict the future->Good Representation->Mutual Information
也就是說想要得到好的預測我們要最大化input \(x_t\)和context \(c_t\)互信息(Mutual Information),即盡可能多的用\(c_t\)去表達原始信號\(x\)

\[I(x;c)=\sum_{x,c}p(x,c)\log \frac {p(x|c)} {p(x)} \]

所以論文不采用生成模型\(p(x_{t+k}|c)\)進行預測,而是最大化mutual information使得預測的\(\tilde z_{t+k}\)與真實的\(z_{t+k}\)盡可能接近
但是\(p(x,c)\)無法直接獲得,所以要提出一個模型去近似未來真實數據與隨機采樣數據的概率之比

\[f_k(x_{t+k},c_t)\propto \frac {p(x_{t+k},c_t)} {p(x_{t+k})} \]

a simple log-bilinear model

\[f_k(x_{t+k},c_t)=\exp \left(z_{t+k}^TW_kc_t \right) \]

\(W\)的下標\(k\)是指預測未來不同時刻時要用到不同的參數,\(z_{t+k}^T\)是真實值,用向量內積來衡量相似度

Method

InfoNCE Loss

一個batch中的N個隨機樣本包括

  • 一個正樣本從\(p(x_{t+k}|c_t)\)中采樣:來自與當前的上下文\(c_t\)相隔\(k\)個步長的樣本
  • 剩余\(N-1\)個負樣本從與\(c_t\)無關的\(p(x_{t+k})\)分布中取得:來自從序列隨機選取的樣本

\((x_{t+k},c_t)\)可以看成正樣本對,\((x_j,c_t)\)可以看成負樣本對

The loss is the categorical cross-entropy of classifying the positive sample correctly.

InfoNCE Loss定義如下,相當於一個多分類交叉熵損失

優化該損失函數,應使分子盡可能大,也就是正樣本對之間的互信息更大,負樣本對之間的互信息更小。優化該損失,實際上就是最大化\(x_{t+k}\)\(c_t\)間的互信息
這里有個問題,負樣本是隨機采樣的,那么負樣本中也可能有與要預測的結果相關的樣本信息

問題解決:在具體實踐時,常常在對一個batch進行訓練時,把當前sample的\((x_{t+k}^i,c_t^i)\)當作正樣本對,把batch中其他samples和當前sample的預測值配對\((x_{t+k}^j,c_t^i)\)來計算

Mutual Information Estimation

上述損失函數的optimal情況:假設\(x_i\)\(c_t\)的預測結果,即正樣本,那么\(x_i\)從條件分布\(p(x_{t+k}|c_t)\)中采樣出來的概率如下,也就是f的最優解

可以看出\(f_k(x_{t+k},c_t)\)確實與\(\displaystyle \frac {p(x_{t+k}|c_t)} {p(x_{t+k})}\)成比例

於是把\(\displaystyle \frac {p(x_{t+k}|c_t)} {p(x_{t+k})}\)帶入到InfoNCE Loss中

證明了最小化InfoNCE也就是最大化互信息

Experiments

圖像分類上的應用,用7x7個64x64大小小的grid在256x256的圖上去crop,crop間有50%重疊,每個crop送入encoder(ResNet-101),把前幾個patch作為輸入,預測后面的patch

參考博客

參考視頻


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM