【煉丹技巧】
在深度學習中,經常會使用EMA(指數移動平均)這個方法對模型的參數做平均,以求提高測試指標並增加模型魯棒。
今天瓦礫准備介紹一下EMA以及它的Pytorch實現代碼。
EMA的定義
指數移動平均(Exponential Moving Average)也叫權重移動平均(Weighted Moving Average),是一種給予近期數據更高權重的平均方法。
假設我們有n個數據:
- 普通的平均數:
- EMA:
,其中,
表示前
條的平均值 (
),
是加權權重值 (一般設為0.9-0.999)。
Andrew Ng在Course 2 Improving Deep Neural Networks中講到,EMA可以近似看成過去 個時刻
值的平均。
普通的過去 時刻的平均是這樣的:
類比EMA,可以發現當 時,兩式形式上相等。需要注意的是,兩個平均並不是嚴格相等的,這里只是為了幫助理解。
實際上,EMA計算時,過去 個時刻之前的數值平均會decay到
的加權比例,證明如下。
如果將這里的 展開,可以得到:
其中, ,代入可以得到
。
在深度學習的優化中的EMA
上面講的是廣義的ema定義和計算方法,特別的,在深度學習的優化過程中, 是t時刻的模型權重weights,
是t時刻的影子權重(shadow weights)。在梯度下降的過程中,會一直維護着這個影子權重,但是這個影子權重並不會參與訓練。基本的假設是,模型權重在最后的n步內,會在實際的最優點處抖動,所以我們取最后n步的平均,能使得模型更加的魯棒。
EMA的偏差修正
實際使用中,如果令 ,且步數較少,ema的計算結果會有一定偏差。

理想的平均是綠色的,因為初始值為0,所以得到的是紫色的。
因此可以加一個偏差修正(bias correction):
顯然,當t很大時,修正近似於1。
EMA為什么有效
網上大多數介紹EMA的博客,在介紹其為何有效的時候,只做了一些直覺上的解釋,缺少嚴謹的推理,瓦礫在這補充一下,不喜歡看公式的讀者可以跳過。
令第n時刻的模型權重(weights)為 ,梯度為
,可得:
令第n時刻EMA的影子權重為 ,可得:
代入上面 的表達,令
展開上面的公式,可得:
對比兩式:
EMA對第i步的梯度下降的步長增加了權重系數 ,相當於做了一個learning rate decay。
PyTorch實現
瓦礫看了網上的一些實現,使用起來都不是特別方便,所以自己寫了一個。
class EMA(): def __init__(self, model, decay): self.model = model self.decay = decay self.shadow = {} self.backup = {} def register(self): for name, param in self.model.named_parameters(): if param.requires_grad: self.shadow[name] = param.data.clone() def update(self): for name, param in self.model.named_parameters(): if param.requires_grad: assert name in self.shadow new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name] self.shadow[name] = new_average.clone() def apply_shadow(self): for name, param in self.model.named_parameters(): if param.requires_grad: assert name in self.shadow self.backup[name] = param.data param.data = self.shadow[name] def restore(self): for name, param in self.model.named_parameters(): if param.requires_grad: assert name in self.backup param.data = self.backup[name] self.backup = {} # 初始化 ema = EMA(model, 0.999) ema.register() # 訓練過程中,更新完參數后,同步update shadow weights def train(): optimizer.step() ema.update() # eval前,apply shadow weights;eval之后,恢復原來模型的參數 def