前言
最近一段時間在搞模型量化(之前量化基礎為0),基本上查到了90%以上的成熟量化方案,QAT的方案真的非常不成熟,基本沒有開源好用的方案。賽靈思挺成熟但僅針對自己的框架,修改代價太大了。阿里的框架不成熟,至少我在看代碼的時候,他還在Fix-Bug。ONNX挺成熟,但使用人數基本沒有,其作為IR工具,很少有人拿他來訓練。。。。量化資料雖然多,但基本都是跑一個分類模型,至於檢測的量化少之又少。
目前狀態
環境:
- MMDetectionV2.15,已重構
- MQBenchV0.3,修改部分代碼,修復部分BUG
- 后端Torch--V1.9.1
- 后端Tensorrt--V8.2
- 后端ONNX-ONNXRuntime--V1.19
簡單試驗YOLOX-Nano
- FP32:mAP17.5%
- QAT(未加載PQT):直接訓練無法收斂、clip-grad收斂到較大loss無法下降。mAP無
- QAT(PQT)(無Augment):mAP12.2%
- QAT(PQT)(Augment):mAP18.4%
- QAT(INT8):mAP18.4%
更新1:
YOLOX-S
- FP32:mAP40.3%
- QAT:mAP39.7%
- QAT(INT8):mAP39.7%
未完全達到fp32的精度,YOLOX-S/Tiny都量化感知訓練精度相比fp32誤差在0.5以內
試驗的方式和阿里加速團隊基本一致,從試驗結果來看整體流程較為完整。
量化理論
量化的理論較為簡單(前向推理加速未涉及):
- \(r\) Float32 data
- \(q\) Quant data
- \(S\) Scale
- \(Z\) ZeroPoint
基本所有的論文都是圍繞以上四個公式進行的,未對具體論文進行總結,僅看代碼給出的幾個簡單的例子:
- Scale穩定性
由於scale的變化幅度過大會對訓練造成嚴重的震盪,所有較低頻率的修改scale才能促進訓練。
def scale_smooth(tensor: torch.Tensor):
log2t = torch.log2(tensor)
log2t = (torch.round(log2t) - log2t).detach() + log2t
return 2 ** log2t
- Scale和ZP的計算
這部分是優化最多且最有效的,因為一個好的初始化至關重要
- 都是比較簡單,代碼一看便明了
- 最大最小值
- 均值方差
- 直方圖
- Clip
- ......
- Learn/Fixed
Scale/ZP是訓練還是固定?
訓練的情況非常花時間,因為量化節點已經得插入上百個,如果再加上訓練,速度慢的可憐🥺!而且得長時間的tune,收斂緩慢。當然效果肯定比固定好🔥。
固定的情況會節約大量時間,但精度略低於訓練的情況。
量化方案
QAT:訓練量化
PQT:訓練后量化
目前主流還是使用PQT,比如Tensorrt、NCNN、MNN、ONNX。。。。基本前向推理的框架都支持訓練后量化。
少數使用QAT(僅在PQT精度較低的情況會使用),比如pytorch、ONNX、TF。。。基本只有訓練框架才支持。
QAT-PyTorch
原始的訓練方案 采用手動插入節點,目前已經完全廢棄,這里不做介紹。
當前主推的方案 基於torch.FX模塊進行,這里簡單介紹一下流程
torch.FX是DAG(雙向-有向圖)結構,和ONNX的Graph類似,但是核心不同。
- FX.Graph是由Node節點構成,每個Node節點表示一個Operate,Node是一個雙向指針。在前向計算的時候遍歷每個Node,Node調用Operate,這個操作可能是函數、類成員、類等。
- ONNX.Graph也是Node構成的,但是其使用list存儲,只有讀取某個Node才知道users和producer,而且權重等參數由另一個數據結構initializer存儲
以下是QAT-pytorch的簡易流程
注意: 其實FX模塊就和ONNX一樣,都是一個IR部件!!如何構建一個FX模塊是難點(本人僅理解大概流程,未做具體分析)
To-Torch
由於使用了MQBench進行操作,無法直接使用官方的FX-convert進行轉化。
由於使用了FX組件,也無法使用原始的convert進行轉化。
考慮到轉換的模型比較簡單,比如檢測YOLOX、YOLOV5,分類Resnet、MobileNet、ShuffleNet,所以沒必要專門寫一套MQBench和torch.convert的轉換工具。
本人直接手寫了一個類進行轉換,轉換采用torch原始的方式(將FX模塊解析出來),YOLOX精度已對齊。
- Torch的前向推理quantized op(比如QConv、QLinear...)只支持非對稱量化(torch.quint8),但是Quantize和DeQuantizae是支持對稱量化(torch.qint8),在實際的模型轉化中只能使用非對稱量化進行。
- function得用module替換
- 需要重寫torch的pqt代碼
注意: 采用此方案坑還是比較多的,建議還是寫個mqbench到torch.fx的中間工具。
To-ONNX
已完成,待寫文檔
To-TRT
已完成,待寫文檔
To-NCNN
已完成,待寫文檔