以下文章來源於極市平台 ,作者CV開發者都愛看的

專注計算機視覺前沿資訊和技術干貨,官網:www.cvmart.net
點擊上方“計算機視覺工坊”,選擇“星標”
干貨第一時間送達
導讀
training-aware-quantization是在訓練中模擬量化行為,在訓練中用浮點來保存定點參數,最后inference的時候,直接采用定點參數。本文闡述了作者選用darknet框架來實現量化的過程,包括如何在訓練中融合BN到CONV以及Uint8推理實現等。量化表現的結果顯示前向時間相比於原來的darknet壓縮明顯,同時精度下降非常低。
量化簡介
在實際神經網絡在例如端側的部署時,由於內存,帶寬或者最重要計算資源的限制,通常會采用量化等手段來加速神經網絡的表現。量化的意思即是將原來浮點運算轉化為定點運算,例如最常見的8bit量化,無論是int8還是uint8,都是將浮點的區間參數映射到256個離散區間上。這樣原來32位的運算就變成了8位的運算
這里我們以非對稱量化到uint8舉例,其中S代表量化因子(scale factor), Z代表zero point.
量化的優點非常明顯,即使除去后處理,反量化或者非對稱量化帶來額外運算,單張圖片的推理速度通常都能獲得2-3倍的提升(這里不討論針對硬件進行特殊優化帶來的加速),但是隨之而來的就是量化造成的精度下降問題。
簡單來說,量化造成精度損失主要來自兩個方面:
-
取整損失,例如r = [6.8, 7.2, -0.6], scale = (7.2+0.6)/127 = 0.061417, q1 = 7.2/scale = 117.23,那么他的量化值就是117,有了0.23的損失
-
截斷損失 ,因為scale是取最優區間,那么邊界的點勢必會有超過最大量化值的情況,這些離群點就會被忽略掉,量化的最大最小值區間相比於原數據分布就有了截斷損失
為了能夠減少量化過程中的精度損失,我們參考google的論文
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
這種方法屬於aware training quantization,與之對應的是post training quantization,后面一種方法是tensorRT使用的量化方法,后面有機會可以把實現的代碼上傳到github上。
事實上,學術界認為8bit的量化已經飽和了,已經開始做4bit的量化研究了,但是在實際的工作過程中,發現對於較小的識別網絡,8bit的量化效果依然不是令人非常滿意。
量化實現
為了方便的部署到嵌入式端,我最初選擇實現框架定在實現語言為C或者C++,最終選定的框架為darknet,一方面darknet在工業界有着不錯的應用群體,二來框架簡單直接,實現起來非常方便,同時還可以驗證反向過程是否正確。在復現過程中,為了能夠將算法成功的集成進去,對darknet做了許多小的修改,正好這里也記錄一下。
代碼鏈接:
https://github.com/ArtyZe/yolo_quantization
偽量化
相信對量化了解的同學都讀過這篇文章,tf-lite都是用的這種量化方式。區別於訓練后量化的方式,google采用的是在訓練過程中加入偽量化來模擬量化過程中由於取整造成的精度損失。
那么偽量化是個什么操作呢?
其中,類似中括號那里就是取整的意思。可以看到,如果說沒有取整這個操作,完全就是減一個數,除一個數,再乘回來,再加回來,完全就沒有任何變化。但是因為有了這個取整,所以這中間就有了變化。
想象一下,如果在訓練過程中,采取了這么一個操作,那不就相當於提前就把量化的損失考慮進去了嗎?這樣等到inference的時候,精度下降就少的多了呀。
那么要把這個偽量化放在哪里呢?
那當然是放在inference的時候需要進行量化的位置,以論文中給出的圖來解析,
卷積的操作用公式來描述無非就是:
所以要量化的就是weights以及feature x。
這時候就有人提出疑問了,可是你看啊,人家給出的圖中是weights和激活值的偽量化啊,你怎么說是input的feature呢,可是如果你這樣想呢,除了第一層真正的輸入之外,剩下的層,上一層的activ輸出值不就是下一層的input值嗎,而且使用activ值有一個什么最大的好處呢?在最后一層將定點值反量化回到浮點值需要用到激活值的scale和zero_point(如果是非對稱量化的話)。
在訓練中融合BN到CONV
我們平時見到的最多的融合BN+CONV就是在inference的時候為了加速做的,但是你細想一下,你BN的參數在inference的時候怎么辦呢?如果inference的時候不融合,那么BN的參數你要怎么量化,如果融合了,那么weights的量化參數是根據融合前生成的啊,那你怎么能用呢?
所以解決方案就是,把BN融合在訓練階段就加進去,如下圖:
具體怎么做呢?
- 首先就的前向跑一遍,計算得到均值,方差等一系列BN的參數
- 然后,利用這些BN的參數,通過融合公式加到input和weights中去,將卷積公式變成真正的
其中
為了后續能夠更新原生 和 該過程中不僅需要保存 和 還需要保存 和 ,至於反向更新過程中,需要使用Straight Through Estimator(STE)來跳過偽量化過程中的round使得梯度可以正常回傳。
- 之后根據不同層的type添加input, weights和activation量化即可。目前我采用的方式是第一層卷積input, weights和activation量化都要有,其他層如route后面的卷積層同樣需要input量化,因為route的activation量化參數直接使用他的輸入層的activation量化參數即可;maxpool或者upsample都是添加activation量化即可。
Uint8推理實現
下面開始介紹定點推理,公式如下
由前面可知
為了保持量綱一致,令,
對上式進行簡單的變換
其中, 是唯一的浮點數, 因此采用 來代表, 和 shift 都是定點值,具體多大需要看精度需要,一般采用32位的值來表示。
-
在進入到正式的推理之前,首先看上式哪些值是常量可以提前計算出來,例如都是常量,其中1代表ft,2代表weights
-
進入到正式推理后,需要注意的問題就是溢出的問題,一般情況下為了防止這種情 況有兩種方式,一種就是使用一個shift來統計溢出的情況,另一種就是直接把輸出范圍擴大,例如8bit的乘加輸出到32bit。下面我們開始計算 及 ,為了能夠盡可能的探索優化速度的極限,gemm函數我們使用的是mkl中的cblas庫函數。
-
得到之后的最后一步操作就是激活,這部分在實際使用過程中也是關乎到量化精度的一個關鍵點。如果激活函數是類似softmax,tanh,swish等非線性函數的話,都要通過lookup table查表的方式,為了能夠盡快的實現,我這里選用的是tiny-yolov3,里面的激活函數都是leaky relu的線性激活函數。
-
其他層例如maxpool,route由於並不涉及到計算操作,因此直接將代碼轉成uint8的即可。
-
在最后一層yolo層的前面需要將uint8反量化回到float類型,方式如下:
后續改進
目前已經實現了yolov3-tiny的所有算子的實現,為了方便,目前使用relu6替代了原來的leakyrelu,包括conv, pooling, route, upsample,這些除了conv全部都是線性的算子,后續會繼續支持leaky relu, softmax, shortcut, elementwise add, concat等非線性算子。
量化performance
為了盡可能的不影響精度,我選擇在yolo層的上面一層conv層不進行量化。測試結果如下,可以看到
傳送門
Github鏈接:https://github.com/ArtyZe/yolo_quantization點擊閱讀原文,即可直接跳轉。◎本文為極市開發者「ArtyZe」原創投稿,轉載請注明來源。
◎極市「項目推薦」專欄,幫助開發者們推廣分享自己的最新工作,歡迎大家投稿。聯系極市小編(fengcall19)即可投稿~
本文僅做學術分享,如有侵權,請聯系刪文。 下載1在「計算機視覺工坊」公眾號后台回復: 深度學習,即可下載深度學習算法、3D深度學習、深度學習框架、目標檢測、GAN等相關內容近30本pdf書籍。
下載2在「計算機視覺工坊」公眾號后台回復: 計算機視覺,即可下載計算機視覺相關17本pdf書籍,包含計算機視覺算法、Python視覺實戰、Opencv3.0學習等。
下載3在「計算機視覺工坊」公眾號后台回復: SLAM,即可下載獨家SLAM相關視頻課程,包含視覺SLAM、激光SLAM精品課程。
重磅!計算機視覺工坊-學習交流群已成立
掃碼添加小助手微信,可申請加入3D視覺工坊-學術論文寫作與投稿 微信交流群,旨在交流頂會、頂刊、SCI、EI等寫作與投稿事宜。
同時也可申請加入我們的細分方向交流群,目前主要有ORB-SLAM系列源碼學習、3D視覺、CV&深度學習、SLAM、三維重建、點雲后處理、自動駕駛、CV入門、三維測量、VR/AR、3D人臉識別、醫療影像、缺陷檢測、行人重識別、目標跟蹤、視覺產品落地、視覺競賽、車牌識別、硬件選型、深度估計、學術交流、求職交流等微信群,請掃描下面微信號加群,備注:”研究方向+學校/公司+昵稱“,例如:”3D視覺 + 上海交大 + 靜靜“。請按照格式備注,否則不予通過。添加成功后會根據研究方向邀請進去相關微信群。原創投稿也請聯系。
▲長按加微信群或投稿
▲長按關注公眾號
3D視覺從入門到精通知識星球:針對3D視覺領域的知識點匯總、入門進階學習路線、最新paper分享、疑問解答四個方面進行深耕,更有各類大廠的算法工程人員進行技術指導。與此同時,星球將聯合知名企業發布3D視覺相關算法開發崗位以及項目對接信息,打造成集技術與就業為一體的鐵桿粉絲聚集區,近2000星球成員為創造更好的AI世界共同進步,知識星球入口:
學習3D視覺核心技術,掃描查看介紹,3天內無條件退款