【機器學習的Tricks】隨機權值平均優化器swa與pseudo-label偽標簽


文章來自公眾號【機器學習煉丹術】

1 stochastic weight averaging(swa)

  • 隨機權值平均
  • 這是一種全新的優化器,目前常見的有SGB,ADAM,

【概述】:這是一種通過梯度下降改善深度學習泛化能力的方法,而且不會要求額外的計算量,可以用到Pytorch的優化器中。

隨機權重平均和隨機梯度下降SGD相似,所以我一般吧SWa看成SGD的進階版本。

1.1 原理與算法

swa算法流程:

【怎么理解】:

  • \(w_{swa}\)做了一個周期為c的滑動平均。每迭代c次,就會對這個\(w_{swa}\)做一次滑動平均。其他的時間使用SGD進行更新。
  • 簡單的說,整個流程是模型初始化參數之后,使用SGD進行梯度下降,迭代了c個epoch之后,將模型的參數用加權平均,得到\(w_{SWA}\),然后現在模型的參數就是\(w_{SWA}\),然后再用SGD去梯度下降c個epoch,然后再加權平均出來一個新的\(w_{SWA}\).

SWA加入了周期性滑動平均來限制權重的變化,解決了傳統SGD在反向過程中的權重震盪問題。SGD是依靠當前的batch數據進行更新,尋找隨機梯度下降隨機尋找的樣本的梯度下降方向很可能並不是我們想要的方向。

論文中給出了一個圖片:

  • 綠線是恆定學習率的SGD,效果並不好,直到SGD在訓練的過程中所見了學習率,才可以得到一個收斂的結果;
  • 而使用Stochastic weight averaging可以在學習率恆定的情況下,快速收斂,而且過程平穩。

1.2 python與實現

這里講如何在pytorch深度學習框架中加入swa作為優化器:

from torchcontrib.optim import SWA

# training loop
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt = torchcontrib.optim.SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
for _ in range(100):
     opt.zero_grad()
     loss_fn(model(input), target).backward()
     opt.step()
opt.swap_swa_sgd()

如果使用了swa的話,那么lr_schedule這個方法就不需要在使用了,非常的方便。

【關於參數】:
使用swa的時候,就直接通過

torchcontrib.optim.SWA(base_opt,swa_start,swa_greq,swa_lr)

來封裝原來的優化器。

  • swa_start:是一個整數,表示經過swa_start個steps后,將學習率切換為固定值swa_lr。(在swa_start之前的step中,lr是0.1,在10個steps之后,lr變成0.05)
  • swa_freq:在swa_freq個step優化之后,會將對應的權重加到swa滑動平均的結果上,相當於算法中的c;
  • 使用opt.swap_swa_sgd()之后,可以將模型的權重替換為swa的滑動平均的權重。

1.3 關於BN

這里有一個問題就是在BatchNorm層訓練的時候,BN層中也是有兩個訓練參數的,使用\(w_{swa}\)重置了模型參數,但是並沒有更新BN層的參數,所以如果有bn層的話,還需要加上:

opt.bn_update(train_loader,model)

2 Pseudo-Label

  • 偽標簽
  • 這是一種半監督的方法。其實非常簡單,就是對於未標記的數據,許納澤預測概率最大的標記作為該樣本的pseudo-label,然后給未標記數據設置一個權重,在訓練過程中慢慢增加未標記數據的權重。

這個方法的loss如下:

非常好理解了,前面一項就是訓練集的loss,后面是測試集的loss,然后用一個\(\alpha(t)\)來做權重。

然后這個\(\alpha(t)\)就是隨着訓練的迭代次數增加而慢慢的線性增加(如果按照原來的論文中的描述):

【一些關於pseudo-label的雜談】

這個方法提出在2013年,然后再2015年作者用entropy信息熵來證明這個方法的有效性。但是證明過程較為牽強。這個偽標簽我在2017年的一個項目中想到了,但是不知道可行不可行自己當時也無法進行證明,就作罷了,沒想到現在看到同樣的方法在2013年就提出來了。有點五味雜陳哈哈。

參考文獻:

  1. Izmailov, Pavel, et al. "Averaging weights leads to wider optima and better generalization." arXiv preprint arXiv:1803.05407 (2018).

  2. Grandvalet, Yves, and Yoshua Bengio. "Semi-supervised learning by entropy minimization." Advances in neural information processing systems. 2005.


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM