[論文理解] Bootstrap Your Own Latent A New Approach to Self-Supervised Learning


Bootstrap Your Own Latent A New Approach to Self-Supervised Learning

Intro

文章提出一種不需要負樣本來做自監督學習的方法,提出交替更新假說解釋EMA方式更新target network防止collapse的原因,同時用梯度解釋online網絡和target不同構帶來的好處。

Intuitions

傳統基於對比學習的自監督方法都是需要構造負樣本來防止mode collapse,但本文提出的方法卻不需要負樣本。試想一下,如果不使用負樣本來防止mode collapse,那么一種比較直觀的方式是使用一個teacher網絡來指導需要訓練的網絡,假設teacher網絡訓練的足夠好(比如全監督訓練),是非collapse的網絡,那么理論上訓練一個目標網絡和teacher網絡對同一輸入的輸出一致,這樣就可以了;但是這樣做是違背自監督學習的初衷的,自監督學習不應依賴預訓練網絡,否則就失去了意義。於是一個簡單的想法,假如去隨機初始化一個teacher網絡,並且認為這個teacher網絡經過隨機初始化是非collapse的網絡,由於隨機初始化的網絡具有一定的圖像先驗,所以理論上也不會導致student網絡學習到collapse的結果,但可能效果會差一點。

基於上述想法,文章做了這樣一個實驗,拿隨機初始化的網絡作為teacher網絡並固定其權重,用目標網絡學習該網絡的輸出,但這樣的結果是雖然防止了collapse,但是特征表示能力不強;但是經過這樣訓練的student網絡效果卻比原始的teacher網絡效果要好。

那么首先問題就來了,理論上student網絡的上限是teacher網絡呀,為什么student網絡最終效果比teacher網絡還要好?

這里的提升其實來自於增廣,teacher和student網絡的輸入並不是完全一致的,而是兩次不同的增廣。

驗證了上述想法,文章自然想去通過在訓練過程中同時“訓練”teacher網絡來達到試student網絡得到提升的效果,因為teacher網絡不“訓練”就能給student網絡帶來提升,而teacher網絡“訓練”到最好的情況是全監督的預訓練,效果顯然是最好的,在這之間如果能夠每次提升student網絡之后,利用student網絡學習到的知識來提升一下teacher網絡,那么這樣是直觀能夠提升最終目標網絡效果的。

Method

文章利用student網絡提升teacher網絡的方式是另teacher網絡以以student網絡EMA的方式更新來實現的,大體框架如下:

過程不難理解,但需要注意的是prediction部分和stop gradient部分,這里teacher網絡和student網絡backbone結構都是一致的,唯一不同是student網絡多了一個prediction部分,那么問題來了,這一結構是否不可或缺,他起到了什么作用?第二點比較好理解,stop gradient部分即不用梯度方式更新teacher網絡,而是使用EMA方式,那么第二個問題就是,為什么使用EMA的方式更新teacher網絡能夠方式collapse?

看看文章對這兩個問題的解釋:

如果teacher網絡的更新過程是gradient descent的,那么顯然他也會陷入collpase,但是teacher網絡的更新方式是EMA,而非gradient descent,文章假定並沒有一種gradient descent的loss能夠同時去更新兩個網絡的參數使得loss最小,那么這樣的loss更新方式其實更類似於GAN的方式,交替更新(這里只更新了一步,第一步並沒有更新),按照這樣的理解,交替過程的第一步,固定target使q達到最優,則有:

\[q^{\star} \triangleq \underset{q}{\arg \min } \mathbb{E}\left[\left\|q\left(z_{\theta}\right)-z_{\xi}^{\prime}\right\|_{2}^{2}\right], \quad \text { where } \quad q^{\star}\left(z_{\theta}\right)=\mathbb{E}\left[z_{\xi}^{\prime} \mid z_{\theta}\right] \]

這一步其實並沒有對參數進行更新,簡單的理解是固定backbone,求解一個最優的q使得q與target的誤差最小,但並不更新q,相當於得到了當前backbone參數下最優的投影,其實也比較合理,畢竟student網絡和target網絡有可能差異較大,通過這一層投影是可以拉近兩者距離的(通過最小化q的方式)。

這時候更新整體參數\(\theta\)(包括backbone和predictor,因為第一步並沒有更新參數),則要求參數\(\theta\)的梯度:

\[\nabla_{\theta} \mathbb{E}\left[\left\|q^{\star}\left(z_{\theta}\right)-z_{\xi}^{\prime}\right\|_{2}^{2}\right]=\nabla_{\theta} \mathbb{E}\left[\left\|\mathbb{E}\left[z_{\xi}^{\prime} \mid z_{\theta}\right]-z_{\xi}^{\prime}\right\|_{2}^{2}\right]=\nabla_{\theta} \mathbb{E}\left[\sum_{i} \operatorname{Var}\left(z_{\xi, i}^{\prime} \mid z_{\theta}\right)\right] \]

從這倆公式可以看出其實本文的假設是backbone的參數更新和predictor的參數更新可以等價為兩個參數更新的EM過程,先求固定backbone參數下predictor q的最優參數,這時候q取得最佳值,然后利用最佳的q去更新backbone和q的參數。但是該過程並沒有詳細的數學證明。

因此,按照文章假設,該過程對參數的求導等價於對條件分布的方差的求導。需要注意的是,對於任意隨機變量X、Y、Z,有\(Var(X|Y) \geq Var(X|Y,Z)\),這里X是target projection,Y是online projection,Z是q作用后的隨機變量,引入Z之后會讓方差變小,也就是會讓上面的梯度變小,所以解釋了為什么本文要引入predictor q(也和上面的理解不謀而合,進一步拉近online分布和target分布的距離)。

對於collapse的分布z,有不等式\(Var(z_{\xi}^{'}|z_\theta) \leq Var(z_{\xi}^{'}|c)\),其中c為常量(collapse分布),梯度下降只會使方差下降,因此能夠找到某一參數\(\theta\)使得其不為常量分布,進而避免了collapse。(感覺有一點牽強,因為不是嚴格小於)。

總結一下,上面兩個問題的解釋分別是,引入predictor q會進一步拉近online分布和target分布的距離進而使得梯度更小,更容易優化;而使用EMA方式更新target網絡作者提出了交替過程假說來解釋這樣更新的效果,從梯度下降角度解釋了常量分布的條件方差一定比非常量分布的條件方差大,從而說明本文的方法可以避免collapse。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM