盤點一下后訓練量化的基本操作


(本文首發於公眾號,沒事來逛逛)

這篇文章簡單聊聊后訓練量化的一些常規操作。

一些基礎知識

在此之前,還是需要先了解一下后訓練量化 (下面簡稱 PTQ,Post-training Quantization) 是啥?具體細節這里就不展開了,不熟悉的讀者歡迎看回我之前的文章 (神經網絡量化入門--后訓練量化)。簡單來說,后訓練量化就是在不重新訓練網絡 (即不更新 weight) 的前提下,獲取網絡的量化參數。

說到量化參數,就不得不祭出量化的基本公式了 (假設用非對稱量化,8bit):

\[r=S(q-z) \tag{1} \]

\[q=clip(round(\frac{r}{S}+Z),0,255) \tag{2} \]

這里面的 \(r\)\(q\) 分別表示量化前的浮點數和量化后的定點數。而 \(S\)\(Z\) 就是兩個重要的量化參數 scale (步長) 和 zero point (零點)。除此之外,還有兩個非常重要的量化參數:\(r_{min}\)\(r_{max}\),分別表示浮點數 \(r\) 的數值范圍。

\(S\)\(Z\)\(r_{min}\)\(r_{max}\) 構成了網絡量化里面四個最重要的量化參數,幾乎所有后訓練量化算法,都是為了找到這幾個東西。這里面,\(S\)\(Z\)\(r_{min}\)\(r_{max}\) 之間又是可以相互轉換的:

\[S = \frac{r_{max}-r_{min}}{q_{max}-q_{min}} \tag{3} \]

\[Z = clip(round(q_{max} - \frac{r_{max}}{S}), 0, 255) \tag{4} \]

公式里面的 \(q_{max}\)\(q_{min}\) 表示定點數 \(q\) 的數值范圍,在量化策略確定之后,數值一般就確定下來了。比如說,我這里采用 8bit 的非對稱量化,那么量化后的數值范圍一般就是 0~255,即 \(q_{max}=255\)\(q_{min}=0\)

公式 (3) 大家都能理解,但公式 (4) 寫法比較多,容易搞暈。我簡單畫了張圖,不明白的同學再琢磨琢磨。

對於一個常規的神經網絡來說,只要我們知道了每一層權重 (weight、bias) 和每一層特征 (feature map) 的 \(S\)\(Z\) (或者 \(r_{min}\)\(r_{max}\),反正可以相互轉換),理論上我們就可以用定點的方式跑網絡了,從而獲得內存訪問和計算效率上的提升。為什么這里要說常規呢?因為有一些激活函數在定點的情況下難以運行,這類函數通常還是只能以浮點的形式計算,在一些只能跑定點運算的芯片上令人頭疼不已。

后訓練量化

了解完這些基礎后,現在回到正題:如何找到合適的量化參數呢?

對於權重而言,在模型訓練完成后數值就基本確定了,而對於 feature map 來說,卻沒法事先得知,因此會用一批矯正數據集 (通常就是訓練集的一小部分) 跑一遍網絡,以此來統計每一層 feature map 的數值范圍。

有了權重和特征的數值范圍后,一種很直接的方法就是根據數值范圍的大小來確定 \(r_{min}\)\(r_{max}\)。我之前的文章()就簡單采用了這種方法。但這種方法容易受到噪聲的影響,比如,有些 weight 或者 feature 中可能存在某些離群點,它的數值較大,但對結果影響又很小,如果把這些數值也統計到 minmax 里面,就容易造成浪費。

舉個栗子,如果某個 weight 里面的數值是 [-0.1, 0.2, 0.3, 255.1],那我們統計出來的 minmax 就是 -0.1 和 255.1,如此一來,0.2、0.3 這樣的數值就會被映射到同一個定點數,信息損失相當嚴重,而它們對結果影響可能遠大於 255.1。因此,在這種情況下,我們寧願把 255.1 損失掉,也希望盡可能把 0.2、0.3 保持下來。

一些簡單的改進

那該如何改進呢?

1. 直方圖截斷

既然離群點影響很大,那最容易想到的解法就是排除這些離群點的干擾。我們可以把 weight 或者 feature map 的數值范圍統計出一個直方圖,根據直方圖舍棄前后 m% 的數值,直接用剩下的數值來確定 minmax。

2. 滑動平均

除此之外,還有一種對 feature map 比較有效的統計方法。 這也是 Google 論文提到的一種技巧\(^1\)。我們把矯正數據集分為幾個 batch,逐次輸入到網絡中統計數值。每次更新數值范圍時,按照 \(r_{max}^t=r_{max}^t*(1-\alpha)+r_{max}^{t-1}*\alpha\) 來更新,其中,\(r_{max}^{t-1}\) 是上一次統計到的最大值。通過控制 \(\alpha\) 的數值,可以控制新數據對歷史統計數據的影響,讓最終統計到的數值能大致涵蓋大部分數值,但又不會被一些離群點主導。

3. 均值和方差

一個不成文的約定:我們通常會假設 weight 和 feature 的數值呈正態分布。在此假設下,我們可以統計出測試數據中 weight 或者 feature 的均值 \(\mu\) 和方差 \(\sigma\),然后,根據正態分布的性質,在區間 \((\mu-3\sigma, \mu+3\sigma)\) 之間的數值占了 99+%,因此,可以令 \(r_{min}=\mu-3\sigma\)\(r_{max}=\mu+3\sigma\),這樣就基本涵蓋了大部分數值,也避免了一些離群點的影響。

當然,如果實際的數值分布不是正態的,比如,是個雙峰分布,那可能就 gg 了。

加點數學的味道

以上這些方法都比較 tricky (直方圖要舍棄多少才合適?滑動平均的 \(\alpha\) 怎么設置?正態分布一定要取 \(3\sigma\)?萬一離群點很重要怎么辦?),效果好壞全靠靈巧的雙手。下面介紹幾種更加 mathematic 的方法,看起來理論更加完備一些 (雖然對於神經網絡來說, 有時候仍然很玄學)。

雖然扯到數學,但其實也沒什么高大上的,無非就是找一些方法,可以讓這些 tricky 的事情更加自動化一些。

這其中最關鍵的,就是找到一種度量信息損失的方法,可以告訴我們,當前取到的 minmax 值合不合適,是不是精度最高的。這些度量方法中,最常用的如歐式距離 (L2距離)、L1距離、KL散度、余弦距離等。

1. 搜索minmax

確定好度量方法后,我們就可以自動化地搜索最合適的數值范圍了。最簡單的思路就是在原本的 minmax 區間內,逐步搜索一個更小的數值范圍,然后計算這個范圍內的信息做了量化后有多少信息損失,損失越小,證明這個數值范圍越合適。

江湖中人用的較多的 TensorRT 量化算法就是基於 KL 散度來搜索 minmax 的。TRT 采用的是 8bit 對稱量化,即正數區間量化到 [0, 127],負數區間量化到 [-128, 0)。量化的大致過程如下:

  1. 首先根據矯正數據集確定數值范圍 [\(r_{min}\), \(r_{max}\)];
  2. 把這個范圍區間划分為 2048 份 (相當於離散化成 2048 個 bin 的直方圖,具體多少 bin 可以調整);
  3. 以最前面的 128 個 bin 作為基准,逐次向后搜索,每次擴增一個 bin 的長度,得到一個新的數值范圍。然后把這個數值范圍重新划分為一個 128 個 bin 的直方圖 \(Q\) (這一步相當於舍棄了部分數值信息,並做了量化);
  4. 那要如何評價當前這個數值范圍是否合適呢?這個時候 KL 散度就能派上用場了。我們把剩下那些沒有搜索到的數值壓縮到當前搜索到的 bin 上,得到一個信息基本沒有損失的直方圖 \(P\),如果我們之前搜索到的 \(Q\)\(P\) 相比信息損失最小 (即 KL 散度最小),那這個 \(Q\) 對應的數值范圍就是最好的數值范圍。不巧的是,KL 散度需要兩個直方圖的 bin 是一樣的 (L1 距離等也有這個要求),而 \(Q\) 之前已經被量化到 128 個 bin 了。為了解決這個問題,需要把 \(Q\)反量化到跟 \(P\) 的 bin 數相同,這樣就可以計算信息損失了。
  5. 重復步驟 3、4,記錄每次搜索的 KL 散度大小,直到搜索完整個范圍。KL 散度最小的搜索范圍,就是理論上信息損失最小的 minmax。

其中的一些關鍵操作如量化、反量化等,限於篇幅這里就不展開講了。

有人可能會問:按照這樣搜索,那是不是最后一次把整個數值范圍都包含進去的時候 KL 散度最小呢?畢竟搜索到最后一步我們沒有舍棄任何數值。有這種疑惑是因為沒有考慮到量化的影響。由於大部分情況下,數值分布都近似於正太分布 (即大部分數值會集中在一個區間內),而隨着搜索范圍增大,離群點會越來越多,但中間那些真正有用的、比較集中的數值就只能用更少的 bin 來表達 (要知道總共只有 128 個 bin 可以承載信息)。因此,絕大部分情況下,舍棄離群點 (outlier) 獲得的收益往往是更大的。

以上就是 TRT 量化的大致過程,基本套路就是:從一個小的搜索范圍逐漸擴大出去,每次搜索都量化一遍信息 (比如划分成固定 bin 數的直方圖),然后用一種度量方式 (KL 散度、L1 距離等) 來衡量完整信息和量化信息之間的差異,差異最小的區間就是我們需要的 minmax。

2. 搜索S和Z

除了 minmax,我們也可以搜索合適的 \(S\)\(Z\)。這里的套路和前面是類似的,也是根據量化前后的信息損失來找出最優解。

假設量化前的浮點 weight 或 feature map 為向量 \(r\),那么量化后為:

\[q=clip(round(\frac{r}{S}+Z),0,255) \tag{5} \]

再進行反量化后得到:

\[\begin{align} \hat r&=S*(q-Z) \notag \\ &=S*(clip(round(\frac{r}{S}+Z),0,255)-Z) \tag{6} \end{align} \]

接下來就可以度量量化的信息損失了,在論文 EasyQuant\(^2\) 中使用了余弦相似性,因此這里我們也以余弦相似性為例。

假設矯正數據集總共有 \(N\) 個樣本,那么平均相似性為:

\[\frac{1}{N}\sum_{i}^Ncos(r_i,\hat{r_i})=\frac{1}{N}\sum_i^N\frac{r_i \hat{r_i}} {||r_i||||\hat{r_i}||} \tag{7} \]

而我們要求解的,就是使得這個相似性最大的 \(S\)\(Z\) (余弦相似性越大,信息損失越小):

\[\underset{S, Z} {\operatorname {max}} \frac{1}{N}\sum_i^N\frac{r_i \hat{r_i}} {||r_i||||\hat{r_i}||} \tag{8} \]

搜索 \(S\)\(Z\) 的方法有很多,比如可以參考前面 TRT 的思路,先設定 \(S\)\(Z\) 的范圍,然后我們用兩個循環分別對 \(S\)\(Z\) 進行搜索遍歷,計算每一步搜索的相似性分數,分數最大的就是我們需要的 \(S\)\(Z\)。這種方法就是通常所說的 Grid Search

不過,由於我們已經有了 \(S\)\(Z\) 的解析式了,所以完全可以梯度下降法,甚至直接對 (8) 式求解析解的方法,獲得最優解。不過暫時沒見過有文章這樣處理,估計是因為解析式里面 round 這個函數不好求導吧。

當然,在網絡結構比較復雜的情況下,單獨針對每一層求解量化參數,並不一定能獲得整個網絡的最優精度,因此在 EasyQuant\(^2\) 論文中有很多 trick 來更好地求解量化精度,這個有機會后面再細講。

總結

水了這么多,總算可以結尾了。這篇文章主要介紹了后訓練量化的一些常用操作,包括如何用直方圖簡單地截取 minmax,以及 TensorRT 量化算法的套路等等。事實上,這些后訓練量化的方法也完全可以用到量化感知訓練中,后者無非是多了對權重的更新學習而已。

講完基本操作,后面就是進階版本了,可能會介紹一些更加前沿的后訓練量化的論文。感興趣的老鐵點個贊和在看可好。

參考

  1. Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
  2. EasyQuant: Post-training Quantization via Scale Optimization

歡迎關注我的公眾號:大白話AI,立志用大白話講懂AI。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM