Swin Transformer
Swin transformer是一個用了移動窗口的層級式(Hierarchical)transformer。其像卷積神經網絡一樣,也能做block以及層級式的特征提取。本篇博客結合網上的資料,對該論文進行學習。
摘要
本篇論文提出了一個新的Transformer,稱作Swin Transformer,其可以作為計算機視覺領域的一個通用的骨干網絡。這是因為ViT這篇論文僅僅是做了分類的任務, 而Swin Transformer在計算機視覺的各個領域都有取代CNN的潛力。但是直接把Transformer用到視覺方面會有兩個問題:一個是視覺實體的尺度有很大的變化(比如無人駕駛任務中一張街景圖片,此時代表同樣語義的一個詞其對應的實體可能有各種各樣的大小,這種問題在NLP就不曾出現),另一個是圖像的像素與文本中的字相比分辨率非常高,如果以像素點作為基本單位的話序列的長度就會迅速增加。針對第二個問題,目前的解決方案要么是以特征圖作為輸入,要么是把圖片打成patch,要么是把圖片畫成一個個小窗口,在窗口里做自注意力。針對上述兩個問題,作者 提出了一種包含移動窗口、具有層級設計的Transfoer——Swin Transformer。這種設計讓兩個相鄰的窗口之間產生了交互(cross-window connection)。同時作者提到,這種分層結構可以靈活地在不同的尺度上建模,並且計算復雜度隨着圖像大小的增大而線性增長(非平方級別增長)。由於這種分層的結構,Swin Transformer像卷積神經網絡一樣擁有了多尺度的特征,因此可以被應用到下游任務中。
引言
上圖是Swin Transformer和ViT的一個對比。ViT每一層都是16倍的下采樣率,不適用於預測密集型任務。同時其對於多尺度特征的把握會變弱,而對於檢測和分割的任務,多尺度的特征是非常重要的。且其自注意力始終是在整張圖上進行,即是一個全局建模,其計算復雜度與圖像大小成二次方關系。因此,Swin Transformer借鑒了CNN的很多設計理念以及其先驗知識:小窗口內算自注意力(認為同一個物體會出現在相鄰的地方,因此小窗口算自注意力其實是夠用的,而全局自注意力實際上有一些浪費資源)。CNN之所以能抓住多尺度的特征是因為池化這個操作(能增大每一個卷積核的感受野),因此Swin Transformer也提出了一個類似池化的操作,把相鄰的小patch合成一個大的patch。
Swin Transformer最關鍵的一個設計元素就是移動窗口,使得窗口與窗口之間可以進行交互,再加上之后的patch merging,合並到transformer最后幾層的時候每一個patch本身的感受野就已經很大了,再加上移動窗口的操作,就相當於實現了全局自注意力。
方法
整體流程
假設有一張ImageNet標准尺寸圖片\(224\times 224\times 3\),首先將其打成patch(\(4\times 4\)而非\(16\times 16\)),得到的圖片尺寸是\(56\times 56\times 48\),其中\(48=4\times 4\times 3\)(3為通道數)。之后是Linear Embedding,把向量的維度變成Transformer能夠接受的值,超參數設置為C。Swin-T的C為96,此處為\(56\times 56\times 96\)。之后前兩個56拉成一個維度\(3136\)(即seq的長度),之后的96變成了每一個token的向量的維度。而3136太長了,transformer不可接受,因此Swin Transformer Block使用了基於窗口的自注意力計算方法。對於每個窗口,其默認只有\(7\times 7 = 49\)個patch,序列長度就變成了49,相比3136而言大大減小。如果對於transformer不做約束的話,輸入序列長度是多少,則輸出序列長度就是多少,因此經過第一個Block之后,輸出的尺寸還是\(56\times 56\times 96\)。之后為了實現層級的結構,需要加入類似CNN中池化的操作,因此就有了圖中的Patch Merging。這里想要下采樣兩倍,因此經歷了如下圖的過程(最后將四個張量在通道維度拼接),向量變為\(28\times 28\times 384\)(其中\(384=96\times 4\)):
CNN中池化后通道數往往翻倍,因此這里也想要讓其翻倍,而此時通道維是4C而非2C,因此需要用\(1\times 1\)卷積核把通道維將為\(2C\),所以經過第一個Patch Merging后向量大小為\(28\times 28\times 192\)。之后再經歷一個transformer block其大小不變,因此第二個stage結束后大小為\(28\times 28\times 192\)。以此類推,第四個stage結束后輸出大小為\(7\times 7\times 768\)。
在最后,直接使用全局平均池化將\(7\times 7\)變成1來做分類。然而Swin Transformer並非只做分類,因此這一部分可以修改,作者也就沒有畫出來。
基於自注意力的移動窗口
作者首先介紹了這樣做的動,即全局自注意力計算會導致平方級別的計算復雜度,進而提到使用窗口來做自注意力。原來的圖片會被平均分成沒有重疊的窗口。以第一層輸入為例,其大小為\(56\times 56\times 96\),將其切分為\(8\times 8\)個窗口,每個窗口內有\(M\times M\)個patch(文章中M默認為7)。自注意力都是在小窗口完成的,序列長度永遠為49。對於計算復雜度,作者進行了如下估計(假設划分為\(h\times w\)個patch):

第一個公式是標准的多頭自注意力,計算過程如下:
而對於使用窗口的多頭自注意力,一個窗口內計算的還是多頭自注意力,可以直接套用前一個公式,對於每個窗口,input大小變為\(M^2\times C\)(h和w變為m和m),因此一個窗口的計算復雜度為\(4M^2C^2+2M^4C\),而總共有\(\frac{h}{M}\times \frac{w}{M}\)個窗口,乘起來就是\(4hwC^2+2M^2hwC\)。
對於這種方式,作者說雖然這很好地解決了內存和計算量的問題,但是缺少了窗口和窗口之間的通信,會限制模型的表達能力。因此作者提出了移動窗口的方式。具體的過程參看下圖:

因為每次都是先做一次基於窗口的自注意力,再做一次基於移動窗口的自注意力,所以整體流程的圖里Transformer Block的數量都是偶數(2、2、6、2)。
為了提高移動窗口的計算效率,作者使用了掩碼,以及使用的是相對位置編碼而非絕對位置編碼。

對於上圖而言,雖然已經做到了窗口之間的相互通信,但經過一次移動后從四個窗口變為了九個窗口,且窗口大小不一,這樣就沒法把這些窗口壓為一個batch直接去做自注意力了(窗口大小不一樣,除了batch維沒法合並)。一種解決方式是給非最大尺寸的窗口做padding,但是這樣的計算量就會大大增加。因此作者提出了一種循環移位(cyclic shift)的方式:
這樣窗口的數量固定了(圖中為4),計算復雜度也就固定了。但是這樣的話,有的窗口中的元素原本不在一起,本不應該做自注意力,這里作者提出了比較巧妙的掩碼方式,參照下面的手繪圖:
如果是選擇右上角窗口,則展開后的樣子是條紋狀的。作者給出的四個窗口的掩碼模版如下:
更正一下,這里用-100作為掩碼值是因為自注意力值較小,這是由於LN層的作用以及比的約束例如weight decay(一般模型中間的輸入輸出都比較小以防止過擬合)。
計算完之后,還需要進行逆向的循環移位來還原回去(還需要保持原來圖片的相對位置,不能破壞原來圖像的語義信息)。
這一部分的最后,作者介紹了一下Swin Transformer的幾個變體,對比了Swin Transformer全家桶與ResNet全家桶的復雜度。變量主要為向量通道維C以及每個Stage有多少個Transformer Block。
實驗
作者分別使用ImageNet-1K和ImageNet-22K兩個數據集做預訓練,測試均在ImageNet-1K上進行(在22K數據集上預訓練好的模型需要做fine tune)。
總結
~~~