論文鏈接:
前一篇隨筆AI教父的自監督直覺——SimCLR中介紹了自監督任務的一些動機以及Hinton的方法。在這一篇隨筆中,我們來觀摩下MoCo,該方法在整體形式上更加豐富,動機也十分清晰。文章的作者陣容可以說十分華麗,Kaiming He 以及 Ross Girshick 等都是業界元佬。
主干思路提煉
了解文章的方法全貌只需要看偽代碼足矣。文章的偽代碼使用Pytorch形式,非常接地氣。
'''
f_k與f_q是將輸入信息映射到特征空間的網絡,特征空間由一個長度為C的向量表示。
這里的k可以看作模板,q看作查詢元素,每一個輸入未知圖像的特征由f_q提取,
現在給一系列由f_k提取的模板特征(比如狗的特征、貓的特征),就能使用f_q與f_k的度量值來確定f_q是屬於什么。
在早先的比較學習中,f_k與f_q使用的是同一個網絡,這篇文章的創新點就是,將兩者分開,並且兩者的參數更新方式是不同的。
'''
f_k.params = f_q.params # 初始化
for x in loader: # 輸入一個圖像序列x,包含N張圖,沒有標簽
x_q = aug(x) # 用於查詢的圖(數據增強得到)
x_k = aug(x) # 模板圖(數據增強得到),自監督就體現在這里,只有圖x和x的數據增強才被歸為一類
q = f_q.forward(x_q) # 提取查詢特征,輸出NxC
k = f_k.forward(x_k) # 提取模板特征,輸出NxC
# 不使用梯度更新f_k的參數,這是因為文章假設用於提取模板的表示應該是穩定的,不應立即更新
k = k.detach()
# 這里bmm是分批矩陣乘法
l_pos = bmm(q.view(N,1,C), k.view(N,C,1)) # 輸出Nx1,也就是自己與自己的增強圖的特征的匹配度
l_neg = mm(q.view(N,C), queue.view(C,K)) # 輸出Nxk,自己與上一批次所有圖的匹配度(全不匹配)
logits = cat([l_pos, l_neg], dim=1) # 輸出Nx(1+k)
labels = zeros(N)
# NCE損失函數,就是為了保證自己與自己衍生的匹配度輸出越大越好,否則越小越好
loss = CrossEntropyLoss(logits/t, labels)
loss.backward()
update(f_q.params) # f_q使用梯度立即更新
# 由於假設模板特征的表示方法是穩定的,因此它更新得更慢,這里使用動量法更新,相當於做了個濾波。
f_k.params = m*f_k.params+(1-m)*f_q.params
enqueue(queue, k) # 為了生成反例,所以引入了隊列
dequeue(queue)
疑點:為什么矩陣乘法可以算匹配度?
比如輸入有N個樣本每個樣本有C個特征,它可被表示成NxC得矩陣。
現在有一系列模板樣本,比如M個,那么它可以被表示為MxC矩陣。
現在將這兩個矩陣相乘,得到一個NxM得匹配度矩陣,那么矩陣中i-行,j-列得值就對應輸入的第i個樣本與模板的第j個樣本的相關性。
相關性被記作兩個向量的內積,兩向量方向越趨同(在此基礎上模越大),相關性也就越大。
訓練細節
在第一篇文章中的訓練batch開的很大為256,使用八張卡訓練了53個小時,具體細節比較多,復現可參照原論文。
值得注意的是,批正則化層(BN)使用了一種叫做 Shuffing BN 的方法。 (f_q以及f_k都使用了BN層)
Shuffing BN 方法為:在f_k將參數分散到多卡前(分散是Pytorch基本操作)洗牌其樣本順序,然后再前向傳播后整回原狀。
具體來說,在命令 k = f_k.forward(x_k) 執行前后進行這個操作,這樣保證了每次BN所需的統計信息不僅局限於同一張圖的衍生(x_q及x_k的對應項都是由同一張圖衍生出來的,如果BN都在這種相似的分布下采樣一定會出問題)
重要結果
以下結果是在Imagenet-1M下無監督訓練的,並在驗證集上測試分類。測試分類之前,把f_q固定住,外加一個全連接層訓練一小會兒。可以看到得到的特征表示效果還可以。
下面這個對比主要來看MoCo的動量式(因子0.999)更新f_k方法與直接梯度更新f_k(end-to-end)的對比。
以下為其中一項遷移學習效果,在1M對照下看起來並沒有太多的提升。
分析一波
文章主要還是探討了用於產生模板特征的f_k以及用於產生查詢特征的f_q的參數更新方式。其實這個問題在很多工作中都有探討,在GAN中,有的時候可能先fix住判別器然后再fix住生成器這樣的交替方式易於收斂,也比如在強化學習Actor-Crictic中,用於評價當前步驟好壞的網絡以及用於產生決策的網絡也不是同步更新的(往往fix住一個更新另一個)。這里的方法是采用動量方式去更新f_k,使得用於比較的一端更加穩定,會起到更好的效果。