神經網絡量化入門--后訓練量化


上一篇文章介紹了矩陣量化的基本原理,並推廣到卷積網絡中。這一章開始,我會逐步深入到卷積網絡的量化細節中,並用 pytorch 從零搭建一個量化模型,幫助讀者實際感受量化的具體流程。

本章中,我們來具體學習最簡單的量化方法——后訓練量化「post training quantization」

由於本人接觸量化不久,如表述有錯,歡迎指正。

卷積層量化

卷積網絡最核心的要素是卷積,前文雖然有提及卷積運算的量化,但省略了很多細節,本文繼續深入卷積層的量化。

這里我們繼續沿用之前的公式,用 \(S\)\(Z\) 表示 scale 和 zero point,\(r\) 表示浮點實數,\(q\) 表示定點整數。

假設卷積的權重 weight 為 \(w\),bias 為 \(b\),輸入為 \(x\),輸出的激活值為 \(a\)。由於卷積本質上就是矩陣運算,因此可以表示成:

\[a=\sum_{i}^N w_i x_i+b \tag{1} \]

由此得到量化的公式:

\[S_a (q_a-Z_a)=\sum_{i}^N S_w(q_w-Z_w)S_x(q_x-Z_x)+S_b(q_b-Z_b) \tag{2} \]

\[q_a=\frac{S_w S_x}{S_a}\sum_{i}^N (q_w-Z_w)(q_x-Z_x)+\frac{S_b}{S_a}(q_b-Z_b)+Z_a \tag{3} \]

這里面非整數的部分就只有 \(\frac{S_w S_x}{S_a}\)\(\frac{S_b}{S_a}\),因此接下來就是把這部分也變成定點運算。

對於 bias,由於 \(\sum_{i}^N (q_w-Z_w)(q_x-Z_x)\) 的結果通常會用 int32 的整數存儲,因此 bias 通常也量化到 int32。這里我們可以直接用 \(S_w S_x\) 來代替 \(S_b\),由於 \(S_w\)\(S_x\) 都是對應 8 個 bit 的縮放比例,因此 \(S_w S_x\) 最多就放縮到 16 個 bit,用 32bit 來存放 bias 綽綽有余,而 \(Z_b\) 則直接記為 0。

因此,公式 (3) 再次調整為:

\[\begin{align} q_a&=\frac{S_w S_x}{S_a}(\sum_{i}^N(q_w-Z_w)(q_x-Z_x)+q_b)+Z_a \notag \\ &=M(\sum_{i}^N q_wq_x-\sum_i^N q_wZ_x-\sum_i^N q_xZ_w+\sum_i^NZ_wZ_x+q_b)+Z_a \tag{4} \end{align} \]

其中,\(M=\frac{S_w S_x}{S_a}\)

根據上一篇文章的介紹,\(M\) 可以通過一個定點小數加上 bit shift 來實現,因此公式 (4) 完全可以通過定點運算進行計算。由於 \(Z_w\)\(q_w\)\(Z_x\)\(q_b\) 都是可以事先計算的,因此 \(\sum_i^N q_wZ_x\)\(\sum_i^NZ_wZ_x+q_b\) 也可以事先計算好,實際 inference 的時候,只需要計算 \(\sum_{i}^N q_wq_x\)\(\sum_i^N q_xZ_w\) 即可。

卷積網絡量化流程

了解完整個卷積層的量化,現在我們再來完整過一遍卷積網絡的量化流程。

我們繼續沿用前文的小網絡:

其中,\(x\)\(y\) 表示輸入和輸出,\(a_1\)\(a_2\) 是網絡中間的 feature map,\(q_x\) 表示 \(x\) 量化后的定點數,\(q_{a1}\) 等同理。

在后訓練量化中,我們需要一些樣本來統計 \(x\)\(a_1\)\(a_2\) 以及 \(y\) 的數值范圍「即 min, max」,再根據量化的位數以及量化方法來計算 scale 和 zero point。

本文中,我們先采用最簡單的量化方式,即統計 min、max 后,按照線性量化公式:

\[S = \frac{r_{max}-r_{min}}{q_{max}-q_{min}} \tag{5} \]

\[Z = round(q_{max} - \frac{r_{max}}{S}) \tag{6} \]

來計算 scale 和 zero point。

需要注意的是,除了第一個 conv 需要統計輸入 \(x\) 的 min、max 外,其他層都只需要統計中間輸出 feature 的 min、max 即可。另外,對於 relu、maxpooling 這類激活函數來說,它們會沿用上一層輸出的 min、max,不需要額外統計,即上圖中 \(a_1\)\(a_2\) 會共享相同的 min、max 「為何這些激活函數可以共享 min max,以及哪些激活函數有這種性質,之后有時間可以細說」。

因此,在最簡單的后訓練量化算法中,我們會先按照正常的 forward 流程跑一些數據,在這個過程中,統計輸入輸出以及中間 feature map 的 min、max。等統計得差不多了,我們就可以根據 min、max 來計算 scale 和 zero point,然后根據公式 (4) 中的,對一些數據項提前計算。

之后,在 inference 的時候,我們會先把輸入 \(x\) 量化成定點整數 \(q_x\),然后按照公式 (4) 計算卷積的輸出 \(q_{a1}\),這個結果依然是整型的,然后繼續計算 relu 的輸出 \(q_{a2}\)。對於 fc 層來說,它本質上也是矩陣運算,因此也可以用公式 (4) 計算,然后得到 \(q_y\)。最后,根據 fc 層已經計算出來的 scale 和 zero point,推算回浮點實數 \(y\)。除了輸入輸出的量化和反量化操作,其他流程完全可以用定點運算來完成。

pytorch實現

有了上面的鋪墊,現在開始用 pytorch 從零搭建量化模型。

下文的代碼都可以在 https://github.com/Jermmy/pytorch-quantization-demo 上找到。

基礎量化函數

首先,我們需要把量化的基本公式,也就是公式 (5)(6) 先實現:

def calcScaleZeroPoint(min_val, max_val, num_bits=8):
    qmin = 0.
    qmax = 2. ** num_bits - 1.
    scale = float((max_val - min_val) / (qmax - qmin)) # S=(rmax-rmin)/(qmax-qmin)

    zero_point = qmax - max_val / scale    # Z=round(qmax-rmax/scale)

    if zero_point < qmin:
        zero_point = qmin
    elif zero_point > qmax:
        zero_point = qmax
    
    zero_point = int(zero_point)

    return scale, zero_point

def quantize_tensor(x, scale, zero_point, num_bits=8, signed=False):
    if signed:
        qmin = - 2. ** (num_bits - 1)
        qmax = 2. ** (num_bits - 1) - 1
    else:
        qmin = 0.
        qmax = 2.**num_bits - 1.
 
    q_x = zero_point + x / scale
    q_x.clamp_(qmin, qmax).round_()     # q=round(r/S+Z)
    
    return q_x.float()  # 由於pytorch不支持int類型的運算,因此我們還是用float來表示整數
 
def dequantize_tensor(q_x, scale, zero_point):
    return scale * (q_x - zero_point)    # r=S(q-Z)

前面提到,在后訓練量化過程中,需要先統計樣本以及中間層的 min、max,同時也頻繁涉及到一些量化、反量化操作,因此我們可以把這些功能都封裝成一個 QParam 類:

class QParam:

    def __init__(self, num_bits=8):
        self.num_bits = num_bits
        self.scale = None
        self.zero_point = None
        self.min = None
        self.max = None

    def update(self, tensor):
        if self.max is None or self.max < tensor.max():
            self.max = tensor.max()
        
        if self.min is None or self.min > tensor.min():
            self.min = tensor.min()
        
        self.scale, self.zero_point = calcScaleZeroPoint(self.min, self.max, self.num_bits)
    
    def quantize_tensor(self, tensor):
        return quantize_tensor(tensor, self.scale, self.zero_point, num_bits=self.num_bits)

    def dequantize_tensor(self, q_x):
        return dequantize_tensor(q_x, self.scale, self.zero_point)

上面的 update 函數就是用來統計 min、max 的。

量化網絡模塊

下面要來實現一些最基本網絡模塊的量化形式,包括 conv、relu、maxpooling 以及 fc 層。

首先我們定義一個量化基類,這樣可以減少一些重復代碼,也能讓代碼結構更加清晰:

class QModule(nn.Module):

    def __init__(self, qi=True, qo=True, num_bits=8):
        super(QModule, self).__init__()
        if qi:
            self.qi = QParam(num_bits=num_bits)
        if qo:
            self.qo = QParam(num_bits=num_bits)

    def freeze(self):
        pass

    def quantize_inference(self, x):
        raise NotImplementedError('quantize_inference should be implemented.')

這個基類規定了每個量化模塊都需要提供的方法。

首先是 __init__ 函數,除了指定量化的位數外,還需指定是否提供量化輸入 (qi) 及輸出參數 (qo)。在前面也提到,不是每一個網絡模塊都需要統計輸入的 min、max,大部分中間層都是用上一層的 qo 來作為自己的 qi 的,另外有些中間層的激活函數也是直接用上一層的 qi 來作為自己的 qi 和 qo。

其次是 freeze 函數,這個函數會在統計完 min、max 后發揮作用。正如上文所說的,公式 (4) 中有很多項是可以提前計算好的,freeze 就是把這些項提前固定下來,同時也將網絡的權重由浮點實數轉化為定點整數

最后是 quantize_inference,這個函數主要是量化 inference 的時候會使用。實際 inference 的時候和正常的 forward 會有一些差異,可以根據之后的代碼體會一下。

下面重點看量化卷積層的實現:

class QConv2d(QModule):

    def __init__(self, conv_module, qi=True, qo=True, num_bits=8):
        super(QConv2d, self).__init__(qi=qi, qo=qo, num_bits=num_bits)
        self.num_bits = num_bits
        self.conv_module = conv_module
        self.qw = QParam(num_bits=num_bits)

    def freeze(self, qi=None, qo=None):
        
        if hasattr(self, 'qi') and qi is not None:
            raise ValueError('qi has been provided in init function.')
        if not hasattr(self, 'qi') and qi is None:
            raise ValueError('qi is not existed, should be provided.')

        if hasattr(self, 'qo') and qo is not None:
            raise ValueError('qo has been provided in init function.')
        if not hasattr(self, 'qo') and qo is None:
            raise ValueError('qo is not existed, should be provided.')

        if qi is not None:
            self.qi = qi
        if qo is not None:
            self.qo = qo
        self.M = self.qw.scale * self.qi.scale / self.qo.scale

        self.conv_module.weight.data = self.qw.quantize_tensor(self.conv_module.weight.data)
        self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point

        self.conv_module.bias.data = quantize_tensor(self.conv_module.bias.data, scale=self.qi.scale * self.qw.scale, zero_point=0, signed=True)

    def forward(self, x):
        if hasattr(self, 'qi'):
            self.qi.update(x)

        self.qw.update(self.conv_module.weight.data)

        self.conv_module.weight.data = self.qw.quantize_tensor(self.conv_module.weight.data)
        self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)

        x = self.conv_module(x)

        if hasattr(self, 'qo'):
            self.qo.update(x)

        return x
      
    def quantize_inference(self, x):
        x = x - self.qi.zero_point
        x = self.fc_module(x)
        x = self.M * x + self.qo.zero_point
        return x

這個類基本涵蓋了最精華的部分。

首先是 __init__ 函數,可以看到我傳入了一個 conv_module 模塊,這個模塊對應全精度的卷積層,另外的 qw 參數則是用來統計 weight 的 min、max 以及對 weight 進行量化用的。

其次是 freeze 函數,這個函數主要就是計算公式 (4) 中的 \(M\)\(q_w\) 以及 \(q_b\)。由於完全實現公式 (4) 的加速效果需要更底層代碼的支持,因此在 pytorch 中我用了更簡單的實現方式,即優化前的公式 (4):

\[q_a=M(\sum_{i}^N(q_w-Z_w)(q_x-Z_x)+q_b)+Z_a \tag{7} \]

這里的 \(M\) 本來也需要通過移位來實現定點化加速,但 pytorch 中 bit shift 操作不好實現,因此我們還是用原始的乘法操作來代替。

注意到 freeze 函數可能會傳入 qi 或者 qo​,這也是之前提到的,有些中間的模塊不會有自己的 qi,而是復用之前層的 qo 作為自己的 qi。

接着是 forward 函數,這個函數和正常的 forward 一樣,也是在 float 上進行的,只不過需要統計輸入輸出以及 weight 的 min、max 而已。有讀者可能會疑惑為什么需要對 weight 量化到 int8 然后又反量化回 float,這里其實就是所謂的偽量化節點,因為我們在實際量化 inference 的時候會把 weight 量化到 int8,這個過程本身是有精度損失的 (來自四舍五入的 round 帶來的截斷誤差),所以在統計 min、max 的時候,需要把這個過程帶來的誤差也模擬進去。

最后是 quantize_inference 函數,這個函數在實際 inference 的時候會被調用,對應的就是上面的公式 (7)。注意,這個函數里面的卷積操作是在 int 上進行的,這是量化推理加速的關鍵「當然,由於 pytorch 的限制,我們仍然是在 float 上計算,只不過數值都是整數。這也可以看出量化推理是跟底層實現緊密結合的技術」。

理解 QConv2d 后,其他模塊基本上異曲同工,這里不再贅述。

完整的量化網絡

我們定義一個簡單的卷積網絡:

class Net(nn.Module):

    def __init__(self, num_channels=1):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.conv2 = nn.Conv2d(40, 40, 3, 1, groups=20) # 這里用分組網絡,可以增大量化帶來的誤差
        self.fc = nn.Linear(5*5*40, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 5*5*40)
        x = self.fc(x)
        return x

接下來就是把這個網絡的每個模塊進行量化,我們單獨定義一個 quantize 函數來逐個量化每個模塊:

class Net(nn.Module):

    def quantize(self, num_bits=8):
        self.qconv1 = QConv2d(self.conv1, qi=True, qo=True, num_bits=num_bits)
        self.qrelu1 = QReLU()
        self.qmaxpool2d_1 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
        self.qconv2 = QConv2d(self.conv2, qi=False, qo=True, num_bits=num_bits)
        self.qrelu2 = QReLU()
        self.qmaxpool2d_2 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
        self.qfc = QLinear(self.fc, qi=False, qo=True, num_bits=num_bits)

注意,這里只有第一層的 conv 需要 qi,后面的模塊基本是復用前面層的 qo 作為當前層的 qi。

接着定義一個 quantize_forward 函數來統計 min、max,同時模擬量化誤差:

class Net(nn.Module):
    
    def quantize_forward(self, x):
        x = self.qconv1(x)
        x = self.qrelu1(x)
        x = self.qmaxpool2d_1(x)
        x = self.qconv2(x)
        x = self.qrelu2(x)
        x = self.qmaxpool2d_2(x)
        x = x.view(-1, 5*5*40)
        x = self.qfc(x)
        return x

下面的 freeze 函數會在統計完 min、max 后對一些變量進行固化:

class Net(nn.Module):

    def freeze(self):
        self.qconv1.freeze()
        self.qrelu1.freeze(self.qconv1.qo)
        self.qmaxpool2d_1.freeze(self.qconv1.qo)
        self.qconv2.freeze(qi=self.qconv1.qo)
        self.qrelu2.freeze(self.qconv2.qo)
        self.qmaxpool2d_2.freeze(self.qconv2.qo)
        self.qfc.freeze(qi=self.qconv2.qo)

由於我們在量化網絡的時候,有些模塊是沒有定義 qi 的,因此這里需要傳入前面層的 qo 作為當前層的 qi。

最后是 quantize_inference 函數,就是實際 inference 的時候用到的函數:

class Net(nn.Module):
  
    def quantize_inference(self, x):
        qx = self.qconv1.qi.quantize_tensor(x)
        qx = self.qconv1.quantize_inference(qx)
        qx = self.qrelu1.quantize_inference(qx)
        qx = self.qmaxpool2d_1.quantize_inference(qx)
        qx = self.qconv2.quantize_inference(qx)
        qx = self.qrelu2.quantize_inference(qx)
        qx = self.qmaxpool2d_2.quantize_inference(qx)
        qx = qx.view(-1, 5*5*40)
        qx = self.qfc.quantize_inference(qx)
        out = self.qfc.qo.dequantize_tensor(qx)
        return out

這里我們會將輸入 x 先量化到 int8,然后就是全量化的定點運算,得到最后一層的輸出后,再反量化回 float 即可。

訓練全精度網絡

這一部分代碼在 train.py 中,我們用 mnist 數據集來訓練上面的網絡:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_loader = torch.utils.data.DataLoader(
  datasets.MNIST('data', train=True, download=True, 
                 transform=transforms.Compose([
                   transforms.ToTensor(),
                   transforms.Normalize((0.1307,), (0.3081,))
                 ])),
  batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
  datasets.MNIST('data', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
  ])),
  batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
)

model = Net().to(device)

具體訓練細節比較簡單,這里不再贅述。

訓練完成后,我測試得到的准確率在 98% 左右。

后訓練量化

這一部分代碼在 post_training_quantize.py 中。

我們先加載全精度模型的參數:

model = Net()
model.load_state_dict(torch.load('ckpt/mnist_cnn.pt'))

然后對網絡進行量化:

model.quantize(num_bits=8)

接下來就是用一些訓練數據來估計 min、max:

def direct_quantize(model, test_loader):
    for i, (data, target) in enumerate(test_loader, 1):
        output = model.quantize_forward(data)
        if i % 200 == 0:
            break
    print('direct quantization finish')

簡單起見,我們就跑 200 個迭代。

然后,我們把量化參數都固定下來,並進行全量化推理:

model.freeze()

def quantize_inference(model, test_loader):
    correct = 0
    for i, (data, target) in enumerate(test_loader, 1):
        output = model.quantize_inference(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    print('\nTest set: Quant Model Accuracy: {:.0f}%\n'.format(100. * correct / len(test_loader.dataset)))

quantize_inference(model, test_loader)

由於很多細節都封裝在量化網絡的模塊中了,因此外部調用的代碼跟全精度模型其實很類似。

我自己測試了 bit 數為 1~8 的准確率,得到下面這張折線圖:

發現,當 bit >= 3 的時候,精度幾乎不會掉,bit = 2 的時候精度下降到 69%,bit = 1 的時候則下降到 10%。

這一方面是 mnist 分類任務比較簡單,但也說明神經網絡中的冗余量其實非常大,所以量化在分類網絡中普遍有不錯的效果「不過 bit =3 或 4 的時候效果依然這么好,讓我依稀覺得代碼里面應該有 bug,后續還要反復檢查」。

總結

這篇文章主要補充了卷積層量化的細節,包括 bias 的量化,以及實際 inference 時一些優化的操作。並梳理了完整的卷積網絡量化的流程。然后重點用 pytorch 從零搭建一個量化模型來幫助大家理解其中的細節,以及后訓練量化算法的過程。代碼是參考了這篇文章,加上自己拍腦袋構思的,存在很多不足之處,而且應該有不少 bug 存在,也歡迎大家指正。

之后的文章將繼續講述量化感知訓練的流程,並補充其他量化的細節「例如 conv+relu 的合並等」,感謝大家賞臉關注。

參考


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM