PyTorch之BN核心參數詳解
原始文檔:https://www.yuque.com/lart/ugkv9f/qoatss
affine
初始化時修改
affine 設為 True 時,BatchNorm 層才會學習參數 gamma 和 beta,否則不包含這兩個變量,變量名是 weight 和 bias。
.train()
- 如果
affine==True
,則對歸一化后的 batch 進行仿射變換,即乘以模塊內部的 weight(初值是[1., 1., 1., 1.])然后加上模塊內部的 bias(初值是[0., 0., 0., 0.]),這兩個變量會在反向傳播時得到更新。 - 如果
affine==False
,則 BatchNorm 中不含有 weight 和 bias 兩個變量,什么都都不做。
.eval()
- 如果
affine==True
,則對歸一化后的 batch 進行放射變換,即乘以模塊內部的 weight 然后加上模塊內部的 bias,這兩個變量都是網絡訓練時學習到的。 - 如果
affine==False
,則 BatchNorm 中不含有 weight 和 bias 兩個變量,什么都不做。
修改實例屬性
無影響,仍按照初始化時的設定。
track_running_stats
由於 BN 的前向傳播中涉及到了該屬性,所以實例屬性的修改會影響最終的計算過程。
class _NormBase(Module):
"""Common base of _InstanceNorm and _BatchNorm"""
_version = 2
__constants__ = ['track_running_stats', 'momentum', 'eps',
'num_features', 'affine']
num_features: int
eps: float
momentum: float
affine: bool
track_running_stats: bool
# WARNING: weight and bias purposely not defined here.
# See https://github.com/pytorch/pytorch/issues/39670
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True
) -> None:
super(_NormBase, self).__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
if self.track_running_stats:
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
else:
self.register_parameter('running_mean', None)
self.register_parameter('running_var', None)
self.register_parameter('num_batches_tracked', None)
self.reset_parameters()
...
class _BatchNorm(_NormBase):
...
def forward(self, input: Tensor) -> Tensor:
self._check_input_dim(input)
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
if self.num_batches_tracked is not None: # type: ignore
self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
r"""
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
可以看到這里的bn_training控制的是,數據運算使用當前batch計算得到的統計量(True)
"""
if self.training:
bn_training = True
else:
bn_training = (self.running_mean is None) and (self.running_var is None)
r"""
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
used for normalization (i.e. in eval mode when buffers are not None).
這里強調的是統計量buffer的使用條件(self.running_mean, self.running_var)
- training==True and track_running_stats==False, 這些屬性被傳入F.batch_norm中時,均替換為None
- training==True and track_running_stats==True, 會使用這些屬性中存放的內容
- training==False and track_running_stats==True, 會使用這些屬性中存放的內容
- training==False and track_running_stats==False, 會使用這些屬性中存放的內容
"""
assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
return F.batch_norm(
input,
# If buffers are not to be tracked, ensure that they won't be updated
self.running_mean if not self.training or self.track_running_stats else None,
self.running_var if not self.training or self.track_running_stats else None,
self.weight, self.bias, bn_training, exponential_average_factor, self.eps)
.train()
注意代碼中的注釋:Buffers are only updated if they are to be tracked and we are in training mode. 即僅當為訓練模式且track_running_stats==True
時會更新這些統計量 buffer。
另外,此時self.training==True
。bn_training=True
。
track_running_stats==True
BatchNorm 層會統計全局均值 running_mean 和方差 running_var,而對 batch 歸一化時,僅使用當前 batch 的統計量。
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
使用 momentum 更新模塊內部的 running_mean。
- 如果 momentum 是 None,那么就是用累計移動平均(這里會使用屬性
self.num_batches_tracked
來統計已經經過的 batch 數量),否則就使用指數移動平均(使用 momentum 作為系數)。二者的更新公式基本框架是一樣的:\(x_{new}=(1 - factor) \times x_{cur} + factor \times x_{batch}\)
,只是具體的 \(factor\) 有所不同。- \(x_{new}\) 代表更新后的 running_mean 和 running_var;
- \(x_{cur}\) 表示更新前的running_mean和running_var;
- $x_{batch}$ 表示當前 batch 的均值和無偏樣本方差。
- 累計移動平均的更新中 \(factor=1/num\_batches\_tracked\)。
- 指數移動平均的更新公式是 \(factor=momentum\)。
修改實例屬性
如果設置.track_running_stats==False
,此時self.num_batches_tracked
不會更新,而且exponential_average_factor
也不會被重新調整。
而由於:
self.running_mean if not self.training or self.track_running_stats else None,
self.running_var if not self.training or self.track_running_stats else None,
且此時self.training==True
,並且self.track_running_stats==False
,所以送入F.batch_norm
的self.running_mean&self.running_var
兩個參數都是 None。
也就是說,此時和直接在初始化中設置**track_running_stats==False**
是一樣的效果。
但是要小心這里的~~exponential_average_factor~~
的變化。不過由於通常我們初始化 BN 時,僅僅會送入~~num_features~~
,所以默認會使用~~exponential_average_factor = self.momentum~~
來構造指數移動平均更新運行時統計量。(此時exponential_average_factor
不會發揮作用)
track_running_stats==False
則 BatchNorm 中不含有 running_mean 和 running_var 兩個變量,也就是僅僅使用當前 batch 的統計量來歸一化 batch。
self.register_parameter('running_mean', None)
self.register_parameter('running_var', None)
self.register_parameter('num_batches_tracked', None)
修改實例屬性
如果設置.track_running_stats==True
,此時self.num_batches_tracked
仍然不會更新,因為其初始值是 None。
整體來看,這樣的修改並沒有實際影響。
.eval()
此時self.training==False
。
self.running_mean if not self.training or self.track_running_stats else None,
self.running_var if not self.training or self.track_running_stats else None,
此時送入F.batch_norm
的兩個統計量 buffer 和初始化時的結果是一致的。
track_running_stats==True
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
此時bn_training = (self.running_mean is None) and (self.running_var is None) == False
。所以使用全局的統計量。
對 batch 進行歸一化,公式為 \(y=\frac{x-\hat{E}[x]}{\sqrt{\hat{Var}[x]+\epsilon}}\),注意這里的均值和方差是running_mean 和 running_var,在網絡訓練時統計出來的全局均值和無偏樣本方差。
修改實例屬性
如果設置.track_running_stats==False
,此時bn_training
不變,仍未 False,所以仍然使用全局的統計量。也就是self.running_mean, self.running_var
中存放的內容。
整體而言,此時修改屬性沒有影響。
track_running_stats==False
self.register_parameter('running_mean', None)
self.register_parameter('running_var', None)
self.register_parameter('num_batches_tracked', None)
此時bn_training = (self.running_mean is None) and (self.running_var is None) == True
。所以使用當前 batch 的統計量。
對 batch 進行歸一化,公式為 \(y=\frac{x-{E}[x]}{\sqrt{{Var}[x]+\epsilon }}\),注意這里的均值和方差是batch 自己的 mean 和 var,此時 BatchNorm 里不含有 running_mean 和 running_var。
注意此時使用的是無偏樣本方差(和訓練時不同),因此如果 batch_size=1,會使分母為 0,就報錯了。
修改實例屬性
如果設置.track_running_stats==True
,此時bn_training
不變,仍為 True,所以仍然使用當前 batch 的統計量。也就是忽略self.running_mean, self.running_var
中存放的內容。
此時的行為和未修改時一致。
匯總
圖片截圖自原始文檔。