比較Adam 和Adamw


引用自: https://www.lizenghai.com/archives/64931.html

 

AdamW

AdamW是在Adam+L2正則化的基礎上進行改進的算法。
使用Adam優化帶L2正則的損失並不有效。如果引入L2正則項,在計算梯度的時候會加上對正則項求梯度的結果。

那么如果本身比較大的一些權重對應的梯度也會比較大,由於Adam計算步驟中減去項會有除以梯度平方的累積,使得減去項偏小。按常理說,越大的權重應該懲罰越大,但是在Adam並不是這樣。

而權重衰減對所有的權重都是采用相同的系數進行更新,越大的權重顯然懲罰越大。

在常見的深度學習庫中只提供了L2正則,並沒有提供權重衰減的實現。


Adam+L2 VS AdamW

圖片中紅色是傳統的Adam+L2 regularization的方式,綠色是Adam+weightdecay的方式。可以看出兩個方法的區別僅在於“系數乘以上一步參數值“這一項的位置。

再結合代碼來看一下AdamW的具體實現。

以下代碼來自https://github.com/macanv/BERT-BiLSTM-CRF-NER/blob/master/bert_base/bert/optimization.py中的AdamWeightDecayOptimizer中的apply_gradients函數中,BERT中的優化器就是使用這個方法。

在代碼中也做了一些注釋用於對應之前給出的Adam簡化版公式,方便理解。可以看出update += self.weight_decay_rate * param這一句是Adam中沒有的,也就是Adam中綠色的部分對應的代碼,weightdecay這一步是是發生在Adam中需要被更新的參數update計算之后,並且在乘以學習率learning_rate之前,這和圖片中的偽代碼的計算順序是完全一致的。總之一句話,如果使用了weightdecay就不必再使用L2正則化了。

 

   # m = beta1*m + (1-beta1)*dx
      next_m = (tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
      # v = beta2*v + (1-beta2)*(dx**2)
      next_v = (tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, tf.square(grad)))
      # m / (np.sqrt(v) + eps)
      update = next_m / (tf.sqrt(next_v) + self.epsilon)
      # Just adding the square of the weights to the loss function is *not*
      # the correct way of using L2 regularization/weight decay with Adam,
      # since that will interact with the m and v parameters in strange ways.
      #
      # Instead we want ot decay the weights in a manner that doesn't interact
      # with the m/v parameters. This is equivalent to adding the square
      # of the weights to the loss with plain (non-momentum) SGD.
      if self._do_use_weight_decay(param_name):
        update += self.weight_decay_rate * param
      update_with_lr = self.learning_rate * update
      # x += - learning_rate * m / (np.sqrt(v) + eps)
      next_param = param - update_with_lr

原有的英文注釋中也解釋了Adam和傳統Adam+L2正則化的差異,好了到這里應該能理解Adam了,並且也能理解AdamW在Adam上的改進了。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM