TPU中的脈動陣列及其實現


深度學習飛速發展過程中,人們發現原有的處理器無法滿足神經網絡這種特定的大量計算,大量的開始針對這一應用進行專用芯片的設計。谷歌的張量處理單元(Tensor Processing Unit,后文簡稱TPU)是完成較早,具有代表性的一類設計,基於脈動陣列設計的矩陣計算加速單元,可以很好的加速神經網絡的計算。本系列文章將利用公開的TPU V1相關資料,對其進行一定的簡化、推測和修改,來實際編寫一個簡單版本的谷歌TPU,以更確切的了解TPU的優勢和局限性。

動手寫一個簡單版的谷歌TPU系列目錄

    谷歌TPU概述和簡化

    TPU中的脈動陣列及其實現

    神經網絡中的歸一化和池化的硬件實現

    TPU中的指令並行和數據並行

    Simple TPU的設計和性能評估

    SimpleTPU實例:圖像分類

    拓展

    TPU的邊界(規划中)

    重新審視深度神經網絡中的並行(規划中)

 

    本文將對TPU中的矩陣計算單元進行分析,並給出了SimpleTPU中32×32的脈動陣列的實現方式和采用該陣列進行卷積計算的方法,以及一個卷積的設計實例,驗證了其正確性。代碼地址https://github.com/cea-wind/SimpleTPU/tree/master/lab1

1. 脈動陣列和矩陣計算

    脈動陣列是一種復用輸入數據的設計,對於TPU中的二維脈動陣列,很多文章中構造了脈動陣列的寄存器模型,導致閱讀較為困難,而實際上TPU中的二維脈動陣列設計思路十分直接。譬如當使用4×4的脈動陣列計算4×4的矩陣乘法時,有

image

clip_image002

    如上圖所示,右側是一個乘加單元的內部結構,其內部有一個寄存器,在TPU內對應存儲Weight,此處存儲矩陣B。左圖是一個4×4的乘加陣列,假設矩陣B已經被加載到乘加陣列內部;顯然,乘加陣列中每一列計算四個數的乘法並將其加在一起,即得到矩陣乘法的一個輸出結果。依次輸入矩陣A的四行,可以得到矩陣乘法的結果。

    由於硬件上的限制,需要對傳播路徑上添加寄存器,而添加寄存器相對於在第i個時刻處理的內容變成了i+1時刻處理;這一過程可以進行計算結果上的等效。如下圖所示,采用z-1代表添加一個延時為1的寄存器,如果在縱向的psum傳遞路徑上添加寄存器,為了保證結果正確,需要在橫向的輸入端也添加一個寄存器(即原本在i進行乘加計算的兩個數均在i+1時刻進行計算)。給縱向每個psum路徑添加寄存器后,輸入端處理如右圖所示。(下圖僅考慮第一列的處理)

clip_image002[5]

    當在橫向的數據路徑上添加寄存器時,只要每一列都添加相同延時,那么計算結果會是正確的,但是結果會在后一個周期輸出,如下圖所示

clip_image002[7]

    上述分析可以,一個4×4的乘加陣列可以計算一組4×4的乘加陣列完成計算,而對於其他維度的乘法,則可以通過多次調用的方式完成計算。譬如(4×4)×(4×8),可以將(4×8)的乘法拆分乘兩個4×4的矩陣乘;而對於(4×8)×(8×4),兩個矩陣計算完成后還需要將其結果累加起來,這也是為何TPU在乘加陣列后需要添加Accumulators的原因。最終脈動陣列設計如下所示(以4×4為例)

clip_image002[11]

2. 脈動陣列的實現

    如第一節所述,可通過HLS構建一個脈動陣列並進行仿真。類似TPU中的設計,采用INT8作為計算陣列的輸入數據類型,為防止計算過程中的溢出,中間累加結果采用INT32存儲。由於INT32的表示范圍遠高於INT8,認為計算過程中不存在上溢的可能性,因此沒有對溢出進行處理。脈動陣列的計算結果數據類型為INT32,會在后文進行下一步處理。

    脈動陣列實現的關鍵代碼包括

1. Feature向右側移動

for(int j=0;j<MXU_ROWNUM;j++){
    for(int k=MXU_ROWNUM+MXU_COLNUM-2;k>=0;k--){
        if(k>0)
            featreg[j][k] = featreg[j][k-1];
        else
            if(i<mxuparam.ubuf_raddr_num)
                featreg[j][k] = ubuf[ubuf_raddr][j];
            else
                featreg[j][k] = 0;
    }
}

2. 乘法計算以及向下方移動的psum

for(int j=MXU_ROWNUM-1;j>=0;j--){
    for(int k=0;k<MXU_COLNUM;k++){
        ap_int<32> biasreg;
        biasreg(31,24)=weightreg[MXU_ROWNUM+0][k];
        biasreg(23,16)=weightreg[MXU_ROWNUM+1][k];
        biasreg(15, 8)=weightreg[MXU_ROWNUM+2][k];
        biasreg( 7, 0)=weightreg[MXU_ROWNUM+3][k];
        if(j==0)
            psumreg[j][k] = featreg[j][k+j]*weightreg[j][k] + biasreg;
        else
            psumreg[j][k] = featreg[j][k+j]*weightreg[j][k] + psumreg[j-1][k];
    }
}

    完成代碼編寫后可進行行為級仿真,可以看出整個計算陣列的時延關系

1. 對於同一列而言,下一行的輸入比上一行晚一個周期

Screenshot from 2019-06-11 01-05-31

2. 對於同一行而言,下一列的輸入比上一列晚一個周期(注意同一行輸入數據是一樣的)

Screenshot from 2019-06-11 01-06-08

3. 下一列的輸出結果比上一列晚一個周期

Screenshot from 2019-06-11 01-07-39

 

3. 從矩陣乘法到三維卷積

    卷積神經網絡計算過程中,利用kh×kw×C的卷積核和H×W×C的featuremap進行乘加計算。以3×3卷積為例,如下圖所示,省略Channel方向,拆分kh和kw方向分別和featuremap進行卷積,可以得到9個輸出結果,這9個輸出結果按照一定規律加在一起,就可以得到最后的卷積計算結果。下圖給出了3×3卷積,padding=2時的計算示意圖。按F1-F9給9個矩陣乘法結果編號,輸出featuremap中點(2,1)——指第二行第一個點——是F1(1,1),F2(1,2),F3(1,3),F4(2,1),F5(2,2),F6(2,3),F7(3,1),F8(3,2),F9(3,3)的和。

 clip_image002[13]

    下面的MATLAB代碼闡明了這種計算三維卷積的方式,9個結果錯位相加的MATLAB代碼如下所示

output = out1;
output(2:end,2:end,:) = output(2:end,2:end,:) + out2(1:end-1,1:end-1,:);
output(2:end,:,:) = output(2:end,:,:) + out3(1:end-1,:,:);
output(2:end,1:end-1,:) = output(2:end,1:end-1,:) + out4(1:end-1,2:end,:);
output(:,2:end,:) = output(:,2:end,:) + out5(:,1:end-1,:);
output(:,1:end-1,:) = output(:,1:end-1,:) + out6(:,2:end,:);
output(1:end-1,2:end,:) = output(1:end-1,2:end,:) + out7(2:end,1:end-1,:);
output(1:end-1,:,:) = output(1:end-1,:,:) + out8(2:end,:,:);
output(1:end-1,1:end-1,:) = output(1:end-1,1:end-1,:) + out9(2:end,2:end,:);

    而在實際的HLS代碼以及硬件實現上,部分未使用的值並未計算,因此實際計算的index和上述示意圖並不相同,具體可參考testbench中的配置方法。

4. 其他

    GPU的volta架構中引入了Tensor Core來計算4×4的矩陣乘法,由於4×4的陣列規模較小,其內部可能並沒有寄存器,設計可能類似第一節圖1所示。由於其平均一個周期就能完成4×4矩陣計算,猜測采用第一節中陣列進行堆疊,如下圖所示。

image

    一些FPGA加速庫中利用脈動陣列實現了矩陣乘法,不過不同與TPU中將一個輸入固定在MAC內部,還可以選擇將psum固定在MAC內部,而兩個輸入都是時刻在變化的。這幾種方式是類似的,就不再展開描述了。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM