Xilinx器件INT8優化方法的HLS示例


1、引言

Xilinx器件自帶的DSP48E乘法器能夠實現18x27位的乘法和高達48位的累加,關於Xilinx的DSP如何實現INT8的優化,官方早在2016年發布的WP486白皮書中已經給出了明確的指引。其設計思路是將兩組具備同一系數的INT8乘法計算經過移位拼接,實現由單個DSP完成兩組INT8的乘法和累加,最終實現1.75倍的性能提升。官方已經給出了很好的示例,筆者在此只是分享個人的一些感悟,以及該方法所對應的HLS實現方法。說實話,白皮書在2016年已經給出,現在已經過去6年了,再繼續研究CNN在FPGA端側的加速已經沒有任何意義了,以后是ASIC的天下,而且我的研究也只是重復造輪子,官方提供的FINN和Brevitas工具鏈已經能夠很好的實現FPGA對常用CNN算子的支持,直接拿來可以應對幾乎所用的場景。唉,沒意義~

2、DSP48E工作示例

此處醫用WP486官方文檔的圖片1,如下圖所示。其中輸入為A、B、C和D,實現的輸出為B*(A+D)+C,其中C可以是上一次結果的輸出,從而實現累加器的功能,根據其中對應的核心乘法器位寬為18x27,即輸入數據B的位寬不應超過18bit,輸入數據A和D的邏輯運算結果不應超過27bit。因此如果實現通用的16bit與16bit乘法,一般需要消耗1個DSP;而如果是兩組8bit與8bit的數據相乘,如果能夠先對數據完成移位拼接,然后使得乘法運算后的結果不對另一組數據的乘法運算結果造成影響,那么即可實現INT8計算的優化。

這里優化的前提是乘法運算的要有同一個乘數,即z0=a*c和z1=b*c。其核心優化思想是將y0和y1移位拼接為一個具有更大位寬的Y,計算x*Y,然后再對計算結果進行拆分,即得到z0和z1。 計算過程對應wp486里面的圖2和圖3。

 

 

 3、HLS實現步驟與程序

那么針對上述所提出的優化思路,我們可以在HLS中對其進行仿真驗證,結合圖3中的設計思想,核心步驟在於數據打包與計算結果的拆分。我們以z0=a*c和z1=b*c作為核心示例,同時以CNN中常見的特征圖復用與權重復用兩種計算策略進行說明。

在數據打包的過程中,若a的數據位寬為W_a;b的數據位寬為W_b;c的數據位寬為W_c;則打包后的數據位寬為W_a+W_b+W_c,且該結果不能超過27bit,以及W_c不能超過18bit,其主要原因在於DSP48的輸入為18bitx27bit。

因此筆者針對上述過程,只優化乘法部分,而對累加部分不作具體概述,因此做出如下的驗證測試用例:

  1 // @Time    : 2021.12.20
  2 // @Author  : wuruidong
  3 // @Email   : wuruidong@hotmail.com
  4 // @FileName: MAC_8bit_tb.cpp
  5 // @Software: Vivado HLS 2018.3
  6 // @Cnblogs : https://www.cnblogs.com/ruidongwu
  7 
  8 #include <hls_half.h>
  9 #include <ap_fixed.h>
 10 #include <iostream>
 11 using namespace std;
 12 
 13 //iii
 14 template<int A_N, int W0_N, int W1_N>
 15 ap_int<A_N+W0_N+A_N+W1_N> MUL_MAC(ap_int<A_N> A, ap_int<W0_N> W0, ap_int<W1_N> W1)
 16 {
 17     ap_int<W0_N+A_N+W1_N> W;
 18     W = (W0, ap_uint<A_N+W1_N>(0)) + ap_int<W0_N+A_N+W1_N>(W1);
 19 
 20     ap_int<A_N+W0_N> r0;
 21     ap_int<A_N+W1_N> r1;
 22 
 23     (r0, r1) = A*W;
 24 
 25     r0 = r0+r1[A_N+W1_N-1];
 26 
 27     return (r0,r1);
 28 }
 29 
 30 //uuu
 31 template<int A_N, int W0_N, int W1_N>
 32 ap_uint<A_N+W0_N+A_N+W1_N> MUL_MAC(ap_uint<A_N> A, ap_uint<W0_N> W0, ap_uint<W1_N> W1)
 33 {
 34     ap_uint<W0_N+A_N+W1_N> W;
 35     W = (W0, ap_uint<A_N>(0), W1);
 36 
 37     ap_uint<A_N+W0_N> r0;
 38     ap_uint<A_N+W1_N> r1;
 39 
 40     (r0, r1) = A*W;
 41 
 42     //r0 = r0+r1[A_N+W1_N-1];
 43 
 44     return (r0,r1);
 45 }
 46 
 47 //uii
 48 template<int A_N, int W0_N, int W1_N>
 49 ap_int<A_N+W0_N+A_N+W1_N> MUL_MAC(ap_uint<A_N> A, ap_int<W0_N> W0, ap_int<W1_N> W1)
 50 {
 51     ap_int<W0_N+A_N+W1_N> W;
 52     W = (W0, ap_uint<A_N+W1_N>(0)) + ap_int<W0_N+A_N+W1_N>(W1);
 53 
 54     ap_int<A_N+W0_N> r0;
 55     ap_int<A_N+W1_N> r1;
 56 
 57     (r0, r1) = A*W;
 58 
 59     r0 = r0+r1[A_N+W1_N-1];
 60 
 61     return (r0,r1);
 62 }
 63 
 64 //iuu
 65 template<int W_N, int A0_N, int A1_N>
 66 ap_int<W_N+A0_N+W_N+A1_N> MUL_MAC(ap_int<W_N>  W, ap_uint<A0_N> A0, ap_uint<A1_N> A1)
 67 {
 68     ap_uint<A0_N+W_N+A1_N> A;
 69     A = (A0, ap_uint<W_N>(0), A1);
 70 
 71     ap_int<W_N+A0_N> r0;
 72     ap_int<W_N+A1_N> r1;
 73 
 74     (r0, r1) = W*A;
 75 
 76     r0 = r0+r1[W_N+A1_N-1];
 77 
 78     return (r0,r1);
 79 }
 80 
 81 int main(void)
 82 {
 83     ap_int<17> r0, r1;
 84 
 85     ap_uint<8> a=255;
 86     ap_int<9> w0=255, w1=-255;
 87     (r0, r1) = MUL_MAC<8,9,9>(a, w0, w1);
 88     cout<<"uii"<<endl;
 89     cout<<r0.to_int()<<endl;
 90     cout<<r1.to_int()<<endl;
 91 
 92     ap_int<8> ax=-128;
 93     (r0, r1) = MUL_MAC<8,9,9>(ax, w0, w1);
 94     cout<<"iii"<<endl;
 95     cout<<r0.to_int()<<endl;
 96     cout<<r1.to_int()<<endl;
 97 
 98     ap_uint<8> a0=255, a1=255;
 99     ap_int<9> w=-255;
100     (r0, r1) = MUL_MAC<9,8,8>(w, a0, a1);
101     cout<<"iuu"<<endl;
102     cout<<r0.to_int()<<endl;
103     cout<<r1.to_int()<<endl;
104 
105     ap_uint<8> x=255;
106     ap_uint<9> y0=511, y1=511;
107     ap_uint<17> z0, z1;
108     (z0, z1) = MUL_MAC<8,9,9>(x, y0, y1);
109     cout<<"uuu"<<endl;
110     cout<<z0.to_int()<<endl;
111     cout<<z1.to_int()<<endl;
112 
113     return 0;
114 }

上述HLS代碼使用到了C++中的模塊函數與函數重載功能。根據實際的INT8乘法計算過程,可以分為有符號乘有符號(iii)、無符號乘無符號(uuu)、有符號乘無符號(iuu)、無符號乘有符號(uii)四種類型。除了uuu不用考慮二進制補碼的問題,其他均需要考慮。在此處,筆者考慮到CNN量化感知訓練中常用的TFLite格式文件,權重數據通常使用8bit無符號與8bit的固定偏移來表示,即TFLite中權重數據通常使用9bit有符號數來表示(相關鏈接點我),因此在實際的測試中使用9bit與8bit的乘法運算作為示例。

4、使用說明

在實際的使用過程中:

①一般第一層卷積為8bit有符號或無符號的RGB數據,那么CNN加速中通常使用特征圖復用,即iii或uii模式;

②隨着第一層CNN網絡計算結束,Relu函數的調用使得特征圖通常由原來的8bit有符號數轉換為7bit無符號數,若使用特征圖復用,此時使用uii模式;

③若CNN網絡的中間層沒有使用Relu函數,或者使用了LeakyRelu、RLU、Tanh等,則特征圖表示為8bit有符號數,此時使用iii模式;

③再隨着CNN網絡的加深,為了減少片上存儲的消耗和外部DDR交互帶寬的使用,CNN加速轉變為權重復用,此時使用iuu模式;

④一直到CNN網絡的最后一層,如果使用FC分類器,有可能出現權重數據均為無符號數的情況,那么此時可以使用uuu模式(可選,極少出現)。

以上的INT8優化方法在使用的時候,必須要保證為兩組具有同一系數的乘法操作,即對應wp486的圖7。

 

5、總結

結合Xilinx官方給出的INT8優化方法,可以很輕松的實現INT8數據的算力翻倍。同理結合上述設計思路,同樣能夠實現其他bit的計算優化,例如若特征圖為8bit,權重為3bit,則單個DSP可以實現3組乘法優化;若特征圖為4bit,權重為4bit,加上使用特征圖輸入拼接打包為12bit,則單個DSP可以實現4組乘法優化;當然也支持其他數據位寬組合,只要特征圖打包后位寬小於18bit,權重打包后位寬小於27bit,均可實現乘法優化。

 

參考資料:WP486 - Deep Learning with INT8 Optimization on Xilinx Devices White Paper


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM