Reformer: The Efficient Transformer


一、背景與算法介紹

   Transformer結構被廣泛應用與自然語言處理中,並且在許多任務上都產生了當前最好的效果。為了達到進一步的效果,研究人員已經開始訓練更大的Transformer模型。在某些報告的最大配置中,每層參數的數量超過了5億(0.5B),而層的數量增加到了64層。Transformer模型也用於越來越長的序列中,在一個單獨處理的樣本中,序列的長度能達到11k,也就是包含11000個tokens每個序列,甚至還有更長的序列存在。這種大規模的長序列模型,雖然產生了較好的效果,但由於資源的限制,使得這一趨勢正在打破NLP的研究發展。許多大型Transformer模型只能在大型工業研究實驗室中進行實際訓練,而這些並行訓練的模型甚至不能在單個GPU上進行微調,因為它每訓練一步,都需要多個加速器的硬件資源。

  這些大規模的Transformer模型真的需要這么多資源,還是因為不高效導致的呢?參考一下下面的i計算:單層的參數在5億個,需要內存約2GB;每一層的激活結果,為64K tokens, embedding size是1024,batch size是8,共計64k *1k *8=5億個floats,又需要2GB的內存。如果只是這種單層的內存需求,我們使用單個加速器就很容易滿足一個甚至長到64k的序列上。但是在多層上,內存的消耗就是驚人的:

 

  • 由於每一層需要存儲激活結果,所以N層網絡消耗的內存是單層的N倍。
  • Transformer每一層中間的前饋全連接網絡的維度dff要比注意力層的dmodel大的多,所以消耗的內存更多。
  • 序列長度為L的attention在時間和空間的復雜度都是O(L2),所以一個包含64K個tokens長的序列,都將會消耗巨大的內存。

本文引入的Reformer model將通過下面的技術解決這些問題

  • 可逆神經網絡,將只需要存儲一層的激活結果即可,N的因素消失了。
  • 分塊計算前饋全連接層,節省內存。
  • 采用局部敏感哈希技術,近似計算注意力,將時空開銷從O(L2)變為O(L)。

  我們學習這些技術,並且發現跟標准Transformer相比幾乎沒什么影響。可逆神經網絡確實改變了模型結構,但是通過實驗發現,也幾乎沒有什么影響。最后,注意力中的局部敏感哈希是一個更大的變化,可以影響訓練動態,這取決於所使用的並發哈希的數量。我們研究了這個參數,找到了既能高效使用,又能產生與全注意力相接近的效果。

  我們在合成任務上進行了實驗,一個是長度為64K的文本任務(enwik8),一個長度為12K的圖像生成任務(imagenet-64generation)。在這兩個實驗中都表明,Reformer 與標准Transformer結果相當,但運行得更快,特別是在文本任務上,具有一個數量級的內存效率提升。

 

 二、局部敏感哈希Attention

   Transformer的標准注意力計算公式如下:

  

具體詳細計算過程不再贅述,可參考Attention is all you need.

 內存高效的注意力:

  為了計算注意力機制的內存使用情況,我們集中看一下上述公式的注意力計算。先假設Q,K,V的shape都是[batch_size,length,dmodel],這里的主要關注點在QKT,其shape為[batch_size,length,length]。實驗中,我們訓練的序列長度為64K,這種情況下即便batch_size=1,QKT也是一個64k * 64K的矩陣,如果是32-bit floats的話,也將消耗內存16GB,這將阻擋Transformer在長序列上的使用。其實,QKT矩陣並不需要完全存儲在內存中,可以每次分別計算一個qi,計算一次 softmax(qi*KT/√dk) *V 存儲在內存中,然后在反向傳播的時候計算相應的梯度信息。這種方式可能效率有點低下,但卻是非常節省內存的

Shared-QK Transformer:

  在標准Transformer中,Q,K,V是由激活結果A分別通過三個線性層映射得到。但是這里引入了LSH attention,我們需要Q和K是相同的(備注:其實這里讓Q和K相同並不是LSH必須,LSH只需要讓Q、K變成單位向量即可,因為要在單位球面上進行相似查找,本文讓Q和K一樣只是為了方便批處理,加速計算),讓Q和K通過相同的線性映射即可實現該目的。我們稱這樣的模型為shared-QK Transformer,實驗結果表明共享Q、K並沒有影響Transformer的表現效果。

 LSH attention:

  正如上面介紹的,我們每一次只計算一個qi和K的結果,但是我們需要和K中的每一個元素都計算嗎?其實不是,我們只需要關心與qi相近的keys即可,K中的每一個元素從宏觀上理解就是一個word。假設K的長度為64K,也就是有64K個tokens,我們只需要考慮其中的32或者64個最近的keys,那效率將大大提升。如何得到這最近的keys呢?利用Locality sensitive hashing就可以實現,它的基本思路就是距離相近的向量能夠很大概率hash到一個桶內,而相距較遠的向量hash到一個桶內的概率極低。

 

  上圖是LSH的一個簡單示意圖,在示意圖的上部分,x和y不屬於近鄰,所以在三次隨意旋轉后,有兩次投影都不一樣;而在示意圖的下部分,x和y相距很近,在三次的隨意旋轉后,三次都投影都一樣,這就是LSH的基本原理。LSH原理的詳細解釋可以參考Locality Sensitive Hashing(局部敏感哈希)之cross-polytope LSH

   下面我們正式介紹LSH attention,首先重寫標准的attention公式,對於位置i的單個query的一次計算如下:

               

Pi或者:

                         

Pi就是位置i的query需要關注的tokens集合,h代表hash函數,z表示分區函數(即softmax中的規格化項,相當於somax中的分母),為了簡便,這里省去了√d

  對於一個長序列,為了便於統一批處理,修改計算公式如下:

通過公式可以看到,如果不屬於Pi的,都置為∞,相當於mask掉了,L是序列的長度。

  Hash桶容易產生不均勻的分配,跨桶處理是比較困難的;另外,一個桶內的queries和keys數量不一定相等,事實上,有可能存在桶中只有queries而沒有keys的情況。為了避免這種情況,首先通過kj=q/ ||qj|| 確保h(kj)=h(qj);其次,我們外部根據桶號、桶內部依據序列位置對queries進行排序,排序后定義一個置換i->si。排序后的注意力矩陣,同一個桶的將聚集在對角線附近,方便批量處理,提升速度,這點就跟上述說的Shared-QK一樣,如下圖c-d:

我們可以遵循一種批處理方法,其中m個連續查詢的塊(排序后)相互關聯,后面的塊往前看一個塊。按照我們之前的符號,設置如下:

 

在實際中我們設置m=2L / nbuckets,L是序列的長度,每個桶的平均大小是L / nbuckets,所以我們前提假設一個桶成長為平均大小的2倍的概率是極低的。LSH attention的整個處理流程總結在下圖中:

多輪LSH attention:

  單個hash函數,總不可避免的會出現個別相近的items卻被分到不同的桶里,多輪hash {h(1),h(2),...}可以減少這種情況的發生:

這里的多輪 LSH attention可以並行執行。

 

三、可逆Transformer

 可逆殘差網絡:

  可逆殘差網絡的主要思想是:在反向傳播計算的時候,只使用模型參數就可以從下一層的激活結果中恢復任何給定層的激活結果,從而不用保存中間層的激活結果。標准的殘差層從輸入x到輸出y的映射公式是:y=x+F(x),但是可逆層的輸入輸出都是成對的:(x1,x2)->(y1,y2),計算公式如下:

逆向計算公式如下:

 

  可逆殘差網絡細節可以參考大幅減少GPU顯存占用:可逆殘差網絡(The Reversible Residual Network)

 可逆Transformer

  我們將可逆殘差網絡的思想應用到Transformer中,在可逆塊中結合了自注意力層和前饋網絡層。結合上面的可逆殘差公式,F函數變成了自注意力層,G函數變成了前饋網絡層,注意的是每層的歸一化處理放在了殘差塊里面。

  

  可逆Transformer不需要在每一層中存儲激活結果,在后面實驗部分,我們對比使用了相同數量的參數,其表現與標准Transformer一樣。

 

分塊:
每一層Transformer中前饋網絡所用的中間向量維度dff=4k甚至更高維度,依然非常占用內存;然而,一個序列中各個tokens在前饋網絡層的計算是相互獨立的,所以這部分計算可以拆分為c個組塊:

 這一層通常是對所有位置並行操作批量完成的,但是一次只對一個塊執行操作可以減少內存;可逆計算和反向傳播也是分塊進行的。對於字典比較大的模型,在計算 log-probabilities輸出和loss的時候,也是一次計算一個組塊。

 

四、實驗分析

   通過實驗來展示上面介紹的技術效果,我們逐個分析上面的技術,從而能夠更清晰的看出哪種組合能夠影響實驗結果。我們在 imagenet64和enwik8-64K 任務上進行實驗,這里使用3層模型進行實驗,以便與標准Transformer進行對比。參數設置,dmodel = 1024, dff = 4096, nheads = 8。

Shared-QK效果

  共享QK通過設置kj=q/ ||qj||實現,並且阻止注意力放到自己token上面,除非沒有上下文。從下圖實驗結果可以看出共享QK機制並沒有比標准注意力機制效果差。

同時,在enwik8-64K實驗上,似乎訓練的速度更快一些。

可逆層的效果

  這里還是用標准Transformer跟可逆網絡層對比,二者所使用的參數基本一樣,學習曲線圖如下:

二者曲線基本一致,這說明可逆網絡結構在節省內存的前提下,並沒有損傷精度。

LSH attention in Transformer

  相比全注意力機制,LSH注意力是一個近似的方法,從下面的實驗圖可以看出隨着hash函數的增加,精確度也越來越高。

從圖中可以看出,在nrounds = 8的時候,精確度已經跟全注意力機制相匹敵了;但是hash函數越多,計算代價就越高,所以這個超參數可以根據實際計算資源進行調整。

實驗也對比了不同注意力機制的速度,如下圖:

可以看出,隨着序列長度的不斷增加,標准注意力機制變得越來越慢,而LSH注意力機制基本變化不大,提速效果非常明顯。

 

參考鏈接:

 論文:https://arxiv.org/abs/2001.04451

 github:https://github.com/google/trax/tree/master/trax/models/reformer

             https://ai.googleblog.com/2020/01/reformer-efficient-transformer.html

    https://openreview.net/forum?id=rkgNKkHtvB


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM