PyTorch之BN核心參數詳解


PyTorch之BN核心參數詳解

原始文檔:https://www.yuque.com/lart/ugkv9f/qoatss

affine

初始化時修改

affine 設為 True 時,BatchNorm 層才會學習參數 gamma 和 beta,否則不包含這兩個變量,變量名是 weight 和 bias。

.train()

  • 如果affine==True,則對歸一化后的 batch 進行仿射變換,即乘以模塊內部的 weight(初值是[1., 1., 1., 1.])然后加上模塊內部的 bias(初值是[0., 0., 0., 0.]),這兩個變量會在反向傳播時得到更新。
  • 如果affine==False,則 BatchNorm 中不含有 weight 和 bias 兩個變量,什么都都不做。

.eval()

  • 如果affine==True,則對歸一化后的 batch 進行放射變換,即乘以模塊內部的 weight 然后加上模塊內部的 bias,這兩個變量都是網絡訓練時學習到的。
  • 如果affine==False,則 BatchNorm 中不含有 weight 和 bias 兩個變量,什么都不做。

修改實例屬性

無影響,仍按照初始化時的設定。

track_running_stats

由於 BN 的前向傳播中涉及到了該屬性,所以實例屬性的修改會影響最終的計算過程。

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps',
                     'num_features', 'affine']
    num_features: int
    eps: float
    momentum: float
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True
    ) -> None:
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()
    ...

class _BatchNorm(_NormBase):
    ...

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:  # type: ignore
                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.

        可以看到這里的bn_training控制的是,數據運算使用當前batch計算得到的統計量(True)
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).

        這里強調的是統計量buffer的使用條件(self.running_mean, self.running_var)
        - training==True and track_running_stats==False, 這些屬性被傳入F.batch_norm中時,均替換為None
        - training==True and track_running_stats==True, 會使用這些屬性中存放的內容
        - training==False and track_running_stats==True, 會使用這些屬性中存放的內容
        - training==False and track_running_stats==False, 會使用這些屬性中存放的內容
        """
        assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
        assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight, self.bias, bn_training, exponential_average_factor, self.eps)

.train()

注意代碼中的注釋:Buffers are only updated if they are to be tracked and we are in training mode. 即僅當為訓練模式且track_running_stats==True時會更新這些統計量 buffer。

另外,此時self.training==Truebn_training=True

track_running_stats==True

BatchNorm 層會統計全局均值 running_mean 和方差 running_var,而對 batch 歸一化時,僅使用當前 batch 的統計量。

            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

使用 momentum 更新模塊內部的 running_mean。

  • 如果 momentum 是 None,那么就是用累計移動平均(這里會使用屬性self.num_batches_tracked來統計已經經過的 batch 數量),否則就使用指數移動平均(使用 momentum 作為系數)。二者的更新公式基本框架是一樣的:\(x_{new}=(1 - factor) \times x_{cur} + factor \times x_{batch}\)
    ,只是具體的 \(factor\) 有所不同。
    • \(x_{new}\) 代表更新后的 running_mean 和 running_var;
    • \(x_{cur}\) 表示更新前的running_mean和running_var;
    • $x_{batch}$ 表示當前 batch 的均值和無偏樣本方差。
  • 累計移動平均的更新中 \(factor=1/num\_batches\_tracked\)
  • 指數移動平均的更新公式是 \(factor=momentum\)
修改實例屬性

如果設置.track_running_stats==False,此時self.num_batches_tracked不會更新,而且exponential_average_factor也不會被重新調整。
而由於:

            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,

且此時self.training==True,並且self.track_running_stats==False,所以送入F.batch_normself.running_mean&self.running_var兩個參數都是 None。
也就是說,此時和直接在初始化中設置**track_running_stats==False**是一樣的效果。
但是要小心這里的~~exponential_average_factor~~的變化。不過由於通常我們初始化 BN 時,僅僅會送入~~num_features~~,所以默認會使用~~exponential_average_factor = self.momentum~~來構造指數移動平均更新運行時統計量。(此時exponential_average_factor不會發揮作用)

track_running_stats==False

則 BatchNorm 中不含有 running_mean 和 running_var 兩個變量,也就是僅僅使用當前 batch 的統計量來歸一化 batch。

            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
修改實例屬性

如果設置.track_running_stats==True,此時self.num_batches_tracked仍然不會更新,因為其初始值是 None。
整體來看,這樣的修改並沒有實際影響。

.eval()

此時self.training==False

            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,

此時送入F.batch_norm的兩個統計量 buffer 和初始化時的結果是一致的。

track_running_stats==True

            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

此時bn_training = (self.running_mean is None) and (self.running_var is None) == False。所以使用全局的統計量。
對 batch 進行歸一化,公式為 \(y=\frac{x-\hat{E}[x]}{\sqrt{\hat{Var}[x]+\epsilon}}\),注意這里的均值和方差是running_mean 和 running_var,在網絡訓練時統計出來的全局均值和無偏樣本方差

修改實例屬性

如果設置.track_running_stats==False,此時bn_training不變,仍未 False,所以仍然使用全局的統計量。也就是self.running_mean, self.running_var中存放的內容。
整體而言,此時修改屬性沒有影響。

track_running_stats==False

            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)

此時bn_training = (self.running_mean is None) and (self.running_var is None) == True。所以使用當前 batch 的統計量。
對 batch 進行歸一化,公式為 \(y=\frac{x-{E}[x]}{\sqrt{{Var}[x]+\epsilon }}\),注意這里的均值和方差是batch 自己的 mean 和 var,此時 BatchNorm 里不含有 running_mean 和 running_var。
注意此時使用的是無偏樣本方差(和訓練時不同),因此如果 batch_size=1,會使分母為 0,就報錯了。

修改實例屬性

如果設置.track_running_stats==True,此時bn_training不變,仍為 True,所以仍然使用當前 batch 的統計量。也就是忽略self.running_mean, self.running_var中存放的內容。
此時的行為和未修改時一致。

匯總

圖片截圖自原始文檔。

參考


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM