上一篇文章介紹了如何把 BatchNorm 和 ReLU 合並到 Conv 中,這篇文章會介紹具體的代碼實現。本文相關代碼都可以在 github 上找到。

Folding BN
回顧一下前文把 BN 合並到 Conv 中的公式:
其中,\(x\) 是卷積層的輸入,\(w\)、\(b\) 分別是 Conv 的參數 weight 和 bias,\(\gamma\)、\(\beta\) 是 BN 層的參數。
對於 BN 的合並,首先,我們需要熟悉 pytorch 中的 BatchNorm2d
模塊。
pytorch 中的BatchNorm2d
針對 feature map 的每一個 channel 都會計算一個均值和方差,所以公式 (1) 需要對 weight 和 bias 進行 channel wise 的計算。另外,BatchNorm2d
中有一個布爾變量 affine
,當該變量為 true 的時候,(1) 式中的 \(\gamma\) 和 \(\beta\) 就是可學習的, BatchNorm2d
會中有兩個變量:weight
和bias
,來分別存放這兩個參數。而當affine
為 false 的時候,就直接默認 \(\gamma=1\),\(\beta=0\),相當於 BN 中沒有可學習的參數。默認情況下,我們都設置 affine=True
。
我們沿用之前的代碼,先定義一個 QConvBNReLU
模塊:
class QConvBNReLU(QModule):
def __init__(self, conv_module, bn_module, qi=True, qo=True, num_bits=8):
super(QConvBNReLU, self).__init__(qi=qi, qo=qo, num_bits=num_bits)
self.num_bits = num_bits
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(num_bits=num_bits)
這個模塊會把全精度網絡中的 Conv2d 和 BN 接收進來,並重新封裝成量化的模塊。
接着,定義合並 BN 后的 forward 流程:
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training: # 開啟BN層訓練
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> (C,NHW),這一步是為了方便channel wise計算均值和方差
mean = y.mean(1)
var = y.var(1)
self.bn_module.running_mean = \
self.bn_module.momentum * self.bn_module.running_mean + \
(1 - self.bn_module.momentum) * mean
self.bn_module.running_var = \
self.bn_module.momentum * self.bn_module.running_var + \
(1 - self.bn_module.momentum) * var
else: # BN層不更新
mean = self.bn_module.running_mean
var = self.bn_module.running_var
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu(x)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
這個過程就是對 Google 論文的那張圖的詮釋,跟一般的卷積量化的區別就是需要先獲得 BN 層的參數,再把它們 folding 到 Conv 中,最后跑正常的卷積量化流程。不過,根據論文的表述,我們還需要在 forward 的過程更新 BN 的均值、方差,這部分對應上面代碼 if self.training
分支下的部分。
然后,根據公式 (1),我們可以計算出 fold BN 后,卷積層的 weight 和 bias:
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std # 這一步計算gamma'
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else: # affine為False的情況,gamma=1, beta=0
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
上面的代碼直接參照公式 (1) 就可以看懂,其中gamma_
就是公式中的 \(\gamma'\)。由於前面提到,pytorch 的BatchNorm2d
中,\(\gamma\) 和 \(\beta\) 可能是可學習的變量,也可能默認為 1 和 0,因此根據affine
是否為True
分了兩種情況,原理上基本類似,這里就不再贅述。
合並ReLU
前面說了,ReLU 的合並可以通過在 ReLU 之后統計 minmax,再計算 scale 和 zeropoint 的方式來實現,因此這部分代碼非常簡單,就是在 forward 的時候,在做完 relu 后再統計 minmax 即可,對應代碼片段為:
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
...
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu(x) # <-- calculate minmax after relu
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
將 BN 和 ReLU 合並到 Conv 中,QConvBNReLU
模塊本身就是一個普通的卷積了,因此量化推理的過程和之前文章的QConv2d
一樣,這里不再贅述。
實驗
這里照例給出一些實驗結果。
本文的實驗還是在 mnist 上進行,我重新定義了一個包含 BN 的新網絡:
class NetBN(nn.Module):
def __init__(self, num_channels=1):
super(NetBN, self).__init__()
self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
self.bn1 = nn.BatchNorm2d(40)
self.conv2 = nn.Conv2d(40, 40, 3, 1)
self.bn2 = nn.BatchNorm2d(40)
self.fc = nn.Linear(5 * 5 * 40, 10)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = F.max_pool2d(x, 2, 2)
x = self.conv2(x)
x = self.bn2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 5 * 5 * 40)
x = self.fc(x)
return x
量化該網絡的代碼如下:
def quantize(self, num_bits=8):
self.qconv1 = QConvBNReLU(self.conv1, self.bn1, qi=True, qo=True, num_bits=num_bits)
self.qmaxpool2d_1 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
self.qconv2 = QConvBNReLU(self.conv2, self.bn2, qi=False, qo=True, num_bits=num_bits)
self.qmaxpool2d_2 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
self.qfc = QLinear(self.fc, qi=False, qo=True, num_bits=num_bits)
整體的代碼風格基本和之前一樣,不熟悉的讀者建議先閱讀我之前的量化文章。
先訓練一個全精度網絡「相關代碼在 train.py 里面」,可以得到全精度模型的准確率是 99%。
然后,我又跑了一遍后訓練量化以及量化感知訓練,在不同量化 bit 下的精度如下表所示「由於學習率對量化感知訓練的影響非常大,這里順便附上每個 bit 對應的學習率」:
bit | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|---|
后訓練量化 | 10% | 11% | 10% | 35% | 82% | 85% | 85% | 87% |
量化感知訓練 | 10% | 19% | 59% | 91% | 92% | 94% | 94% | 95% |
lr | 0.00001 | 0.0001 | 0.02 | 0.02 | 0.02 | 0.02 | 0.02 | 0.04 |
對比之前文章的結果,加入 BN 后,后訓練量化在精度上的下降更加明顯,而量化感知訓練依然能帶來較大的精度提升。但在低 bit 情況下,由於信息損失嚴重,網絡的優化會變的非常困難。
總結
這篇文章給出了 Folding BN 和 ReLU 的代碼實現,主要是想幫助初學者加深對公式細節的理解。至此,這系列教程基本告一段落,希望能幫助小白們快速入門這一領域。后面會不定期介紹一些我覺得有趣的 AI 技術,感興趣的讀者歡迎吃瓜吐槽。