【煉丹技巧】
在深度學習中,經常會使用EMA(指數移動平均)這個方法對模型的參數做平均,以求提高測試指標並增加模型魯棒。
今天瓦礫准備介紹一下EMA以及它的Pytorch實現代碼。
EMA的定義
指數移動平均(Exponential Moving Average)也叫權重移動平均(Weighted Moving Average),是一種給予近期數據更高權重的平均方法。
假設我們有n個數據: ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUIlNUN0aGV0YV8xJTJDKyU1Q3RoZXRhXzIlMkMrLi4uJTJDKyU1Q3RoZXRhX24lNUQ=.png)
- 普通的平均數:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNvdmVybGluZSU3QnYlN0QlM0QlNUNmcmFjJTdCMSU3RCU3Qm4lN0QlNUNzdW1fJTdCaSUzRDElN0QlNUVuKyU1Q3RoZXRhX2k=.png)
- EMA:
,其中,
表示前
條的平均值 (
),
是加權權重值 (一般設為0.9-0.999)。
Andrew Ng在Course 2 Improving Deep Neural Networks中講到,EMA可以近似看成過去
個時刻
值的平均。
普通的過去
時刻的平均是這樣的:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD12X3QrJTNEJTVDZnJhYyU3QiUyOG4tMSUyOSU1Q2Nkb3Qrdl8lN0J0LTElN0QlMkIlNUN0aGV0YV90JTdEJTdCbiU3RCs=.png)
類比EMA,可以發現當
時,兩式形式上相等。需要注意的是,兩個平均並不是嚴格相等的,這里只是為了幫助理解。
實際上,EMA計算時,過去
個時刻之前的數值平均會decay到
的加權比例,證明如下。
如果將這里的
展開,可以得到:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD12X3QrJTNEKyU1Q2FscGhhJTVFbit2XyU3QnQtbiU3RCslMkIrJTI4MS0lNUNhbHBoYSUyOSUyOCU1Q2FscGhhJTVFJTdCbi0xJTdEJTVDdGhldGFfJTdCdC1uJTJCMSU3RCUyQisuLi4rJTJCJTVDYWxwaGElNUUwJTVDdGhldGFfdCUyOSs=.png)
其中,
,代入可以得到
。
在深度學習的優化中的EMA
上面講的是廣義的ema定義和計算方法,特別的,在深度學習的優化過程中,
是t時刻的模型權重weights,
是t時刻的影子權重(shadow weights)。在梯度下降的過程中,會一直維護着這個影子權重,但是這個影子權重並不會參與訓練。基本的假設是,模型權重在最后的n步內,會在實際的最優點處抖動,所以我們取最后n步的平均,能使得模型更加的魯棒。
EMA的偏差修正
實際使用中,如果令
,且步數較少,ema的計算結果會有一定偏差。
理想的平均是綠色的,因為初始值為0,所以得到的是紫色的。
因此可以加一個偏差修正(bias correction):
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD12X3QrJTNEKyU1Q2ZyYWMlN0J2X3QlN0QlN0IxLSU1Q2JldGElNUV0JTdEKw==.png)
顯然,當t很大時,修正近似於1。
EMA為什么有效
網上大多數介紹EMA的博客,在介紹其為何有效的時候,只做了一些直覺上的解釋,缺少嚴謹的推理,瓦礫在這補充一下,不喜歡看公式的讀者可以跳過。
令第n時刻的模型權重(weights)為
,梯度為
,可得:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0rJTVDYmVnaW4lN0JhbGlnbiU3RCslNUN0aGV0YV9uKyUyNiUzRCslNUN0aGV0YV8lN0JuLTElN0QtZ18lN0JuLTElN0QrJTVDJTVDKyUyNiUzRCU1Q3RoZXRhXyU3Qm4tMiU3RC1nXyU3Qm4tMSU3RC1nXyU3Qm4tMiU3RCslNUMlNUMrJTI2JTNEKy4uLislNUMlNUMrJTI2JTNEKyU1Q3RoZXRhXzEtJTVDc3VtXyU3QmklM0QxJTdEJTVFJTdCbi0xJTdEZ19pKyU1Q2VuZCU3QmFsaWduJTdE.png)
令第n時刻EMA的影子權重為
,可得:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduJTdEK3ZfbislMjYlM0QrJTVDYWxwaGErdl8lN0JuLTElN0QlMkIlMjgxLSU1Q2FscGhhJTI5JTVDdGhldGFfbislNUMlNUMrJTI2JTNEKyU1Q2FscGhhKyUyOCU1Q2FscGhhK3ZfJTdCbi0yJTdEJTJCJTI4MS0lNUNhbHBoYSUyOSU1Q3RoZXRhXyU3Qm4tMSU3RCUyOSUyQiUyODEtJTVDYWxwaGElMjklNUN0aGV0YV9uKyU1QyU1QyslMjYlM0QrLi4uKyU1QyU1QyslMjYlM0QrJTVDYWxwaGElNUVuK3ZfMCUyQiUyODEtJTVDYWxwaGElMjklMjglNUN0aGV0YV9uJTJCJTVDYWxwaGElNUN0aGV0YV8lN0JuLTElN0QlMkIlNUNhbHBoYSU1RTIlNUN0aGV0YV8lN0JuLTIlN0QlMkIuLi4lMkIlNUNhbHBoYSU1RSU3Qm4tMSU3RCU1Q3RoZXRhXyU3QjElN0QlMjkrJTVDZW5kJTdCYWxpZ24lN0Q=.png)
代入上面
的表達,令
展開上面的公式,可得:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduJTdEK3ZfbislMjYlM0QrJTVDYWxwaGElNUVuK3ZfMCUyQiUyODEtJTVDYWxwaGElMjklMjglNUN0aGV0YV9uJTJCJTVDYWxwaGElNUN0aGV0YV8lN0JuLTElN0QlMkIlNUNhbHBoYSU1RTIlNUN0aGV0YV8lN0JuLTIlN0QlMkIuLi4lMkIlNUNhbHBoYSU1RSU3Qm4tMSU3RCU1Q3RoZXRhXyU3QjElN0QlMjklNUMlNUMrJTI2JTNEKyU1Q2FscGhhJTVFbit2XzAlMkIlMjgxLSU1Q2FscGhhJTI5JTI4JTVDdGhldGFfMS0lNUNzdW1fJTdCaSUzRDElN0QlNUUlN0JuLTElN0RnX2klMkIlNUNhbHBoYSUyOCU1Q3RoZXRhXzEtJTVDc3VtXyU3QmklM0QxJTdEJTVFJTdCbi0yJTdEZ19pJTI5JTJCLi4uJTJCKyU1Q2FscGhhJTVFJTdCbi0yJTdEJTI4JTVDdGhldGFfMS0lNUNzdW1fJTdCaSUzRDElN0QlNUUlN0IxJTdEZ19pJTI5JTJCJTVDYWxwaGElNUUlN0JuLTElN0QlNUN0aGV0YV8lN0IxJTdEJTI5JTVDJTVDKyUyNiUzRCslNUNhbHBoYSU1RW4rdl8wJTJCJTI4MS0lNUNhbHBoYSUyOSUyOCU1Q2ZyYWMlN0IxLSU1Q2FscGhhJTVFbiU3RCU3QjEtJTVDYWxwaGElN0QlNUN0aGV0YV8xLSU1Q3N1bV8lN0JpJTNEMSU3RCU1RSU3Qm4tMSU3RCU1Q2ZyYWMlN0IxLSU1Q2FscGhhJTVFJTdCbi1pJTdEJTdEJTdCMS0lNUNhbHBoYSU3RGdfaSUyOSslNUMlNUMrJTI2JTNEKyU1Q2FscGhhJTVFbit2XzAlMkIlMjgxLSU1Q2FscGhhJTVFbiUyOSU1Q3RoZXRhXzErLSU1Q3N1bV8lN0JpJTNEMSU3RCU1RSU3Qm4tMSU3RCUyODEtJTVDYWxwaGElNUUlN0JuLWklN0QlMjlnX2klNUMlNUMrJTI2JTNEKyU1Q3RoZXRhXzErLSU1Q3N1bV8lN0JpJTNEMSU3RCU1RSU3Qm4tMSU3RCUyODEtJTVDYWxwaGElNUUlN0JuLWklN0QlMjlnX2krJTVDZW5kJTdCYWxpZ24lN0Q=.png)
對比兩式:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduJTdEKyU1Q3RoZXRhX24rJTI2JTNEKyU1Q3RoZXRhXzEtJTVDc3VtXyU3QmklM0QxJTdEJTVFJTdCbi0xJTdEZ19pKyU1QyU1Qyt2X24rJTI2JTNEKyU1Q3RoZXRhXzErLSU1Q3N1bV8lN0JpJTNEMSU3RCU1RSU3Qm4tMSU3RCUyODEtJTVDYWxwaGElNUUlN0JuLWklN0QlMjlnX2krJTVDZW5kJTdCYWxpZ24lN0Q=.png)
EMA對第i步的梯度下降的步長增加了權重系數
,相當於做了一個learning rate decay。
PyTorch實現
瓦礫看了網上的一些實現,使用起來都不是特別方便,所以自己寫了一個。
class EMA(): def __init__(self, model, decay): self.model = model self.decay = decay self.shadow = {} self.backup = {} def register(self): for name, param in self.model.named_parameters(): if param.requires_grad: self.shadow[name] = param.data.clone() def update(self): for name, param in self.model.named_parameters(): if param.requires_grad: assert name in self.shadow new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name] self.shadow[name] = new_average.clone() def apply_shadow(self): for name, param in self.model.named_parameters(): if param.requires_grad: assert name in self.shadow self.backup[name] = param.data param.data = self.shadow[name] def restore(self): for name, param in self.model.named_parameters(): if param.requires_grad: assert name in self.backup param.data = self.backup[name] self.backup = {} # 初始化 ema = EMA(model, 0.999) ema.register() # 訓練過程中,更新完參數后,同步update shadow weights def train(): optimizer.step() ema.update() # eval前,apply shadow weights;eval之后,恢復原來模型的參數 def 