Swin Transformer
論文地址:https://arxiv.org/abs/2103.14030
項目地址:https://github.com/microsoft/Swin-Transformer
摘要
本文提出了一種新的 vision Transformer,稱為Swin Transformer,它可以作為計算機視覺的通用backbone。Transformer從語言到視覺的轉換面臨很大的挑戰,它主要來自於兩個領域之間的差異,例如視覺實體的規模變化很大,圖像中的像素與文本中的單詞相比分辨率很高。為了解決這些差異,我們提出了一個 hierarchical Transformer(層次 Transformer),其表征是用shifted window計算的。滑動窗口方案通過將自注意計算限制在非重疊的局部窗口上,同時允許跨窗口連接,從而提高了效率。這種hierarchical體系結構具有在各種尺度下建模的靈活性,並且具有與圖像大小相關的線性計算復雜性。Swin-Transformer的這些特性使其能夠兼容廣泛的視覺任務,包括圖像分類(ImageNet-1K上的准確率為86.4 top-1)和密集預測任務,如目標檢測(COCO test-dev上的58.7 box AP和51.1 mask AP)和語義分割(ADE20K val上的53.5 mIoU)。它的性能超過了以前的先進水平,COCO上的box-AP和mask-AP分別為+2.7和+2.6,ADE20K上的mask-AP和+320萬,顯示了基於Transformer的模型作為視覺backbone的潛力。
1、介紹
在計算機視覺建模過程中,CNN網絡取得了優良的性能表現,過去幾年基於CNN網絡做了大量的工作;而在NLP領域,發展至今,Transformer越來越成為baseline,它對於處理長期依賴有較好的表現。它在語言領域的巨大成功促使研究人員研究它對計算機視覺的適應性,最近它在某些任務上顯示了有希望的結果,特別是圖像分類[19]和聯合視覺語言建模[46]。
在本文中,我們試圖擴展Transformer的適用性,使它可以像CNNs一樣作為計算機視覺的通用backbone。我們觀察到,將其在語言領域的高性能轉移到視覺領域的重大挑戰可以解釋為兩種模式之間的差異。其中一個差異涉及規模。與作為語言Transformer中處理的基本元素的單詞tokens不同,視覺元素在規模上可能有很大的差異,這是一個在目標檢測等任務中受到關注的問題[41,52,53]。目前基於Transformer的方法,tokens是固定大小的,這個特性並不能適用於視覺。另一個差異主要是,相比於文本,視覺的像素具有更高的分辨率。如實例分割我們需要在像素級處理、計算,這樣 self-attention 的計算復雜度就非常高了。為了克服這個問題,我們提出了通用Transformer backbone: (Swin Transformer),該算法構造層次化特征映射,計算復雜度與圖像大小呈線性關系。如下圖所示:
Swin-Transformer構造了一個層次表示,從小尺寸的像素塊(用灰色表示)開始,逐漸合並更深層次的像素塊。有了這些分層特征映射,Swin-Transformer模型可以方便地利用先進的技術進行密集預測,如特征金字塔網絡(FPN)[41]或U-Net[50]。線性計算復雜度是通過在分割圖像的非重疊窗口(紅色輪廓)內局部計算自我注意來實現的。每個窗口中的像素塊數是固定的,因此復雜度與圖像大小成線性關系。這使得Swin-Transformer 作為backbone可以適應各種視覺任務。而之前用於視覺的Transformer技術只使用了單層特征圖,且擁有二次復雜度。
Swin Transformer的一個關鍵設計元素是它在連續的自我關注層之間的窗口分區移動,如圖2所示。

移動窗口(shifted window)橋接了前一層的窗口,提供了它們之間的連接,顯著增強了建模能力(見表4)。這種策略對於延遲也是有效的:一個窗口中的所有查詢像素塊共享相同的key,這有助於硬件中的內存訪問。我們的實驗表明,所提出的移動窗方法比滑動窗方法有更低的延遲,但在建模能力上是相似的。
Swin Transformer實現了更強的性能表現:在延遲相似的前提下,它比ResNe(X)t models、ViT / DeiT 都要好!實現了\(58.7\%\)的box AP,\(51.1\%\)的mask AP,測試數據 COCO test-dev set,相比於之前的SOTA模型分別提高了\(2.7 P\)、\(2.6P\);mIoU值提高了3.2,在ImageNet-1K上達到\(86.4\%\)的Top-1 accuracy。
我們相信,一個跨計算機視覺和自然語言處理的統一體系結構可以使這兩個領域都受益,因為它將促進視覺和文本信號的聯合建模,並且來自這兩個領域的建模知識可以更深入地共享。我們希望Swin Transformer在各種視覺問題上的出色表現能夠推動社區加深這種信念,並鼓勵視覺和語言信號的統一建模。
2、相關工作
2.1 CNN及其變體
AlexNet ---> VGG、GoogleNet、ResNet、 DenseNet、 HRNet、EffificientNet。
基於這些演進產生了著名的:
- 深度可分離卷積;
- 形變卷積;
2.2 基於backbone結構的自注意力機制
self-attention layers目前被學者熱衷與替換ResNet中的某個卷積,這里主要是基於局部窗口優化,它們確實是提高了性能。但是提高性能的同時,也增加了計算復雜度。我們使用shift windows替換原始的滑動窗口,它允許在一般硬件中更有效地實現。
2.3 Self-attention/Transformers 作為 CNNs 的補充
見名知義:在傳統的backbone后面添加Self-attention/Transformers結構。我們的工作探索了Transformer對基本視覺特征提取的適應,是對這些工作的補充。
2.4 基於Transformer的backbone
與Swin Transformer最相關的工作是ViT和它的繼承者。ViT的開創性工作是在不重疊的中等尺寸圖像塊上直接應用一種Transformer結構進行圖像分類。與卷積網絡相比,它在圖像分類上實現了令人印象深刻的速度精度折衷。但是ViT需要大量的圖片才能訓練好網絡,DeiT改進了訓練策略,使得需要的圖片集 變小 。雖然ViT在圖片分類上有所提高,但是它不適合於高分辨率圖片,因為它的復雜度是圖片大小的二次方。將ViT模型應用於目標檢測和語義分割等稠密視覺任務中的直接上采樣或解卷積算法,效果相對較差。而我們的工作也是魔改了ViT,使其在圖片分類任務上進一步提升。根據經驗,我們發現我們的Swin-Transformer架構可以在這些圖像分類方法中實現最佳的速度-精度折衷,盡管我們的工作側重於通用性能,而不是專門針對分類。有其他學習和也在做多尺度分辨率的融合工作,但是其復雜度還是二次,而我們的復雜度是線性復雜度。我們的模型是兼顧了模型性能與速度,在COCO目標檢測和ADE20K語義分割上達到新的SOTA。
3、方法
3.1 整體架構
圖3給出了Swin-Transformer體系結構的概述,說明了微型版本(Swin-T)。

它首先通過像ViT一樣的分片模塊將輸入的RGB圖像分片成不重疊的像素塊。每個像素塊被視為一個“token”,其特征被設置為原始像素RGB值的串聯。我們使用的像素塊是\(4 \times 4\)的size,所以其特征維度為\(4 \times 4 \times 3 = 48\)。在這個原始值特征上應用一個線性嵌入層,將其投影到任意維(表示為\(C\))。
在stage1中,幾個Swin Transformer blocks算子被應用於這些像素塊上。這些 Transformer blocks保持了\(\frac{H}{4} \times \frac{W}{4}\)的tokens數量,並且伴隨線性的嵌入層。
stage2中,為了產生一個層次化的表示,由於像素塊的合並使得tokens的數量減少了。第一次patch merging layer合並了\(2 \times 2\)領域內的像素塊,並且使用一個線性層在\(4C\)的特征上進行合並。這個操作減少了\(2 \times 2 = 4\)倍的tokens,並設置輸出的維度為\(2C\)。這里的Transformer blocks應用於特征變換后,tokens的數量變為\(\frac{H}{8} \times \frac{W}{8}\)。這第一個像素塊融合和特征變換被稱為stage2。這種操作進行疊加產生了stage3、stage4,如圖所示,tokens的數量分別為:\(\frac{H}{16} \times \frac{W}{16}\)、\(\frac{H}{32} \times \frac{W}{32}\)。這些階段共同產生一個層次表示,具有與典型卷積網絡相同的特征圖分辨率,如VGG [51] and ResNet [29]。結果表明,該體系結構可以很方便地取代現有方法中的backbone,用於各種視覺任務。
3.1.1 Swin Transformer block
Swin Transformer塊使用了shifted windows替換了傳統的多頭注意力機制MSA,如上圖3(b)。Swin Transformer block是由基於MSA的shifted windows組成,它的前面有LN(LayerNorm)層,后面有LN + MLP包圍,且有殘差進行連接。

3.2 基於自注意力的Shifted Window
標准的Transformer架構適應於圖像分類,主要采用了相對位置編碼的全局自注意力機制,而全局計算的復雜度是二次的,表現為tokens的數量。這在很多視覺任務中都會帶來速度損失,且在高分辨率下表現出很強的不適應性。
3.2.1 非重疊窗口的自注意力
為了高效地計算,我們采用局部窗口。這些窗口是均勻排列,且相互不重疊。假設窗口包含\(M \times M\)個像素塊,則 global MSA 在\(h \times w\)的圖像上的計算復雜度為:
這里MSA的復雜度是hw的二次方,而\(M^{2}\)是遠小於\(hw\)的,所以它是\(hw\)的一次復雜度。
3.2.2 連續塊的移位窗口划分
基於窗口的自注意模塊缺乏跨窗口的連接,這限制了它的建模能力。為了在保持非重疊窗口計算效率的同時引入跨窗口連接,我們提出了一種 shifted window 划分方法,該方法在連續的Swin-Transformer塊中交替使用兩種划分配置。
如圖2所示,第一個模型使用了正則窗口划分策略。從左上角開始,將\(8 \times 8\)的像素塊划分為\(M \times M (M=4)\)個\(2 \times 2\)的像素塊。下一個模型的策略是shifted。對上一層划分的窗口進行移動:向左上角移動\(\left ( \left \lfloor \frac{M}{2} \right \rfloor , \left \lfloor \frac{M}{2} \right \rfloor \right )\)個\(2 \times 2\)的像素塊。利用移窗划分方法,連續的Swin-Transformer塊的計算公式為:
上面的公式對應了圖3(b)所示的結構。
移動窗口划分策略,實現了相鄰的非重疊窗口的連接,經過實驗我們發現它對於圖像分類、目標檢測、語義分割是高效的!參考Table 4.
注意: W-MSA中的像素塊特征數是一致的,但是SW-MSA它可不是一致的,這個怎么計算呢?
3.2.3 shifted策略的高效batch計算
shifted操作使得像素塊patches的個數\(\left \lceil \frac{h}{M} \right \rceil \times \left \lceil \frac{w}{M} \right \rceil\)從變為\(\left (\left \lceil \frac{h}{M} \right \rceil + 1 \right ) \times \left (\left \lceil \frac{w}{M} \right \rceil + 1 \right )\)。如圖2所示。且這里部分窗口的大小不是\(M \times M\)。最簡單的方法是直接將所有窗口padding到同樣的大小。如果正則策略划分的較小,如上面的\(2 \times 2\),那么將增加計算量。然而我們提出了一種批量的高效計算方法:循環向左上角移動(cyclic-shifting),如下圖所示:
cyclic-shifting實際上就是將移動造成的非\(M \times M\)像素塊合並為\(M \times M\)像素塊,或者你可以理解為之前是窗口在移動,而現在是特征圖在移動,超過左上角window的部分在右下腳進行填充!經過cyclic-shifting的調整,實際上每個像素塊的大小又一致了,同時我們還實現了不同的patch之間的信息融合,且patches的個數沒有發生變化。
3.2.4 相對位置偏置
在自注意力模塊的計算中,我們引入相對位置偏置到每一個頭的計算:
其中\(Q,K,V \in \mathbb{R}^{M^{2} \times d}\)分別標識查詢矩陣、鍵矩陣、值矩陣;\(d\)是查詢矩陣或鍵矩陣的dimension,\(M^{2}\)是窗口內的像素塊個數。
Note: 這里的像素塊是基本像素塊單元,即上文的\(2 \times 2\)像素塊,\(M^{2}\)即窗口內的\(2 \times 2\)像素塊個數。
因為相對位置的范圍是\(\left[ -M +1, M-1\right]\),我們參數化了一個小尺寸的bias矩陣\(\hat{B} \in \mathbb{R}^{\left( 2M-1\right)\times \left( 2M-1\right)}\),並且我們的\(B\)是從\(\hat{B}\)中取的一個token。
如表4所示,我們觀察到與沒有此偏差項或使用絕對位置嵌入的對應項相比有顯著改進。如[19]中所述,進一步向輸入中添加絕對位置嵌入會略微降低性能,因此在我們的實現中不采用這種方法。
預訓練中學習到的相對位置bias矩陣也可用於初始化模型\(\hat{B}\),以便通過雙三次插值以不同的窗口大小進行微調[19,60]。
3.3 結構變體
作者設置的模型架構有:
- Swin-T:\(C=96 \quad layer \, numbers={2,2,6,2}\);
- Swin-S:\(C=96 \quad layer \, numbers={2,2,18,2}\);
- Swin-B:\(C=128 \,\, layer \, numbers={2,2,18,2}\);
- Swin-L:\(C=192 \,\, layer \, numbers={2,2,18,2}\);
這里的最基礎模型是Swin-B模型,它和ViT-B/DeiT-B模型的計算復雜度一樣。Swin-T, Swin-S and Swin-L分別是繼承模型size的\(0.25\times,0.5\times,2\times\)的放縮。要注意的是Swin-T, Swin-S的計算復雜度分別與ResNet-50 (DeiT-S) and ResNet-101相當。所有實驗的配置中,默認窗口的大小設置為7;每個頭的查詢矩陣維度\(d=32\),且MLP擴張層的\(\alpha=4\)。其中\(C\)是第一階段中隱藏層的通道數。ImageNet圖像分類模型變量的模型大小、理論計算復雜度(FLOPs)和吞吐量如表1所示。
4、構建模型接口
接口: Swin-Transformer.models.build,build_model
這里主要對要構建的模型進行識別,Swin項目當然只支持Swin項目,如果有其他配置需要添加,可以在else部分修改。
主要的代碼為:
model = SwinTransformer(img_size=config.DATA.IMG_SIZE,
patch_size=config.MODEL.SWIN.PATCH_SIZE,
in_chans=config.MODEL.SWIN.IN_CHANS,
num_classes=config.MODEL.NUM_CLASSES,
embed_dim=config.MODEL.SWIN.EMBED_DIM,
depths=config.MODEL.SWIN.DEPTHS,
num_heads=config.MODEL.SWIN.NUM_HEADS,
window_size=config.MODEL.SWIN.WINDOW_SIZE,
mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
qkv_bias=config.MODEL.SWIN.QKV_BIAS,
qk_scale=config.MODEL.SWIN.QK_SCALE,
drop_rate=config.MODEL.DROP_RATE,
drop_path_rate=config.MODEL.DROP_PATH_RATE,
ape=config.MODEL.SWIN.APE,
patch_norm=config.MODEL.SWIN.PATCH_NORM,
use_checkpoint=config.TRAIN.USE_CHECKPOINT)
Note: 這里作者使用yacs.config進行配置,相關內容可以參考Swin-Transformer.config.py。
SwinTransformer類的參數幾乎是見名知義,關於參數的具體含義為:
"""
Args:
img_size (int | tuple(int)): 輸入圖像大小. Default 224
patch_size (int | tuple(int)): 像素塊的大小. Default: 4
in_chans (int): 輸入圖像的通道數. Default: 3
num_classes (int): 分類數. Default: 1000
embed_dim (int): 像素塊編碼的維度. Default: 96
depths (tuple(int)): 每個Swin Transformer層的深度.
num_heads (tuple(int)): 不同層的attention頭數量.
window_size (int): shfited window的大小,即M. Default: 7
mlp_ratio (float): MLP hidden dim到embedding dim的比率. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): 是否添加絕對位置編碼. Default: False
patch_norm (bool): 是否在patch embedding后添加normalization. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
"""
5、SwinTransformer
SwinTransformer的初始化參數前面已經介紹過,下面我們梳理其模型構建的組件----屬性變量。
self.num_features
這個變量並沒有參數直接傳遞,其計算為:int(embed_dim * 2 ** (len(depths) - 1))。
self.ape、self.absolute_pos_embed
self.ape是否使用絕對位置編碼。self.absolute_pos_embed在初始化的時候,是一個shape為(1, num_patches, embed_dim)的全零張量,並使用trunc_normal_進行截斷正態分布初始化,其標准差為0.02。
將PatchEmbed與絕對位置編碼相加 [絕對位置編碼為可選] 之后,再對合並后的特征圖進行隨機dropout。再根據self.layers生成每一個stage的前向傳播。stage的構建是基於BasicLayer生成!可參考5.2。
SwinTransformer的主要由 PatchEmbed + layer1 + layer2 + layer3 + layer4 構成。
分類模型這里主要接:LN + AvgPool1d + Head: nn.Linear
5.1 PatchEmbed
self.patch_embed
將圖像分割為不重疊的像素塊。
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None
)
PatchEmbed是基於torch實現的層。
參數:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
其他屬性:
-
self.patches_resolution:像素塊分辨率,即高寬對應的像素塊分割數
[img_size[0] // patch_size[0], img_size[1] // patch_size[1]] -
self.num_patches:像素塊的數量,即 patches_resolution[0] * patches_resolution[1]
-
self.proj:這里使用卷積進行編碼,卷積核的大小就是patch_size,步長也是patch_size。所以這個卷積處理后patch的shape變為\(1\times1\times embed\_dim\)。
前向傳播:
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
前向傳播首先驗證輸入特征圖的shape是否滿足初始化中的image配置self.img_size。然后進行patch上的卷積,再進行Normalization layer標准化[可選]。如圖所示:
總結PatchEmbed:
像素塊編碼主要是使用卷積在原features上進行卷積,然后再接一個可選的BN層。卷積的卷積核大小就是像素塊的大小即 patch_size,步長也為 patch_size。這樣實現了不同像素塊之間不會有信息融合,也即只對像素塊進行編碼,這與傳統的CNN滑窗相比,減少了大量的卷積操作!卷積后的features如圖所示,在dim=2和dim=3的維度上將張量平鋪,在將最后兩個維度轉置,得到patch卷積輸出,其shape為:(batch_size, ph*pw, channels)。BN層就是pytorch中的BN層。
5.2 BasicLayer簡介
BasicLayer從圖像上很好理解,它是論文中stage1~stage4的生成單元。
build layers:
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layer)
BasicLayer參數:
"""
dim (int): 輸入的通道數.
input_resolution (tuple[int]): 輸入特征圖的分辨率.
depth (int): blocks數量,即當前stage的深度.
num_heads (int): attention頭的數量.
window_size (int): shifted局部窗口的大小.
mlp_ratio (float): mlp 隱含層的維度到embedding dim的比率.
qkv_bias (bool, optional): 給query, key, value添加一個可學習的偏置. 默認: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. 默認: 0.0
attn_drop (float, optional): Attention dropout rate. 默認: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. 默認: 0.0
norm_layer (nn.Module, optional): Normalization layer. 默認: nn.LayerNorm
downsample (nn.Module | None, optional): 在stage的最后進行下采樣的Downsample layer. 默認: None
use_checkpoint (bool): 是否使用checkpoint對當前層進行保存. 默認: False.
"""
每一個stage都是由depth個SwinTransformerBlock組成,就如何殘差神經網絡中的殘差塊一樣。在前向傳播中,BasicLayer非常簡單,就是使用SwinTransformerBlock構建基本框架再加一個可選的下采樣層。
其中SwinTransformerBlock的引入為:
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
這里的self.blocks主要包含了當前stage的主要SwinTransformerBlock塊。
下采樣層主要是做像素的融合。
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
SwinTransformerBlock參考5.3。
5.3 SwinTransformerBlock
參數:
"""
dim (int): 輸入的通道數.
input_resolution (tuple[int]): 輸入特征圖的分辨率.
num_heads (int): attention頭的數量.
window_size (int): 局部窗口的大小.
shift_size (int): SW-MSA的移動size.
mlp_ratio (float): mlp 隱含層的維度到embedding dim的比率.
qkv_bias (bool, optional): 給query, key, value添加一個可學習的偏置. 默認: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. 默認: 0.0
attn_drop (float, optional): Attention dropout rate. 默認: 0.0
drop_path (float, optional): Stochastic depth rate. 默認: 0.0
act_layer (nn.Module, optional): 激活層. 默認: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. 默認: nn.LayerNorm
"""
attn_mask
當shift_size為0時,attn_mask為None。當它不為0時,那么那么由於窗口的移動,會讓原來的特征圖被划分為9個區域。這9個區域的划分原則是:新窗口來源是否是移動后合並產生的。以\(input\_resolution=(8,8)\)為例,有如下區域划分:
# 生成全零張量
img_mask = torch.zeros((1, H, W, 1))
# 按區域划分mask
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
此時的img_mask.squeeze(dim=3)為:
tensor([[[0., 0., 0., 0., 1., 1., 2., 2.],
[0., 0., 0., 0., 1., 1., 2., 2.],
[0., 0., 0., 0., 1., 1., 2., 2.],
[0., 0., 0., 0., 1., 1., 2., 2.],
[3., 3., 3., 3., 4., 4., 5., 5.],
[3., 3., 3., 3., 4., 4., 5., 5.],
[6., 6., 6., 6., 7., 7., 8., 8.],
[6., 6., 6., 6., 7., 7., 8., 8.]]])
然后我們可以獲取新生成的windows的mask:
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
新生成的每個窗口的mask為:
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
這里的attn_mask會傳給WindowAttention用於窗口內的多頭注意力計算。實際就是在WindowAttention中的softmax之前將添加偏置的\(QK^{T} / \sqrt{d} + B\)再加一個mask信息。如最后依據所示,不等於0的那些點全部將mask值置為\(-100\)。這樣實現了對移動拼接產生的window注意力輸出產生一個偏置。
前向傳播:
-
第一步:檢測定義的輸入分辨率是否與輸入的特征圖x的L(序列長度)相同;
-
第二步:使用self.norm1進行特征標准化,再將數據view(B, H, W, C);
-
第三步:使用torch.roll移動特征圖;
-
第四步:使用window_partition划分窗口,這里是在shifted_x上面划分,得到(num_windows*B, window_size, window_size, C)的特征圖,再view(-1, self.window_size * self.window_size, C)。
-
第五步:實現W-MSA/SW-MSA 結構
# num_windows*B, window_size*window_size, C attn_windows = self.attn(x_windows, mask=self.attn_mask)其中 x_windows 是 shifted_x 的窗口划分,self.attn 是WindowAttention的的實例。W-MSA/SW-MSA 的實現區別主要為是否使用shifted。
SwinTransformerBlock主要就是W-MSA/SW-MSA的實現,其結構為:\(LN + (W-MSA/SW-MSA) + LN + MLP\)。要注意的是shifted的特征圖最后我們會還原。這里的LN為nn.LayerNorm;MLP為作者自己的實現。
MLP:
MLP是:全連接層 + 激活層 + Dropout層 + 全連接層 + Dropout層
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
5.4 WindowAttention
基於帶有相對位置偏置的多頭注意力模型的移動/非移動窗口注意力模型。
參數:
"""
dim (int): 輸入通道數.
window_size (tuple[int]): 局部窗口的大小.
num_heads (int): attention頭的數量.
qkv_bias (bool, optional): 給query, key, value添加一個可學習的偏置. 默認: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. 默認: 0.0
proj_drop (float, optional): Dropout ratio of output. 默認: 0.0
"""
窗口注意力層的初始化:
相對位置偏置
# 2*Wh-1 * 2*Ww-1, nH
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(self.relative_position_bias_table, std=.02)
self.relative_position_bias_table使用了截斷正太分布進行初始化,標准差為0.02。
如論文中所示,\(M\)標識窗口的大小,那么初始化的偏置矩陣是\(\hat{B} \in \mathbb{R}^{\left( 2M-1\right)\times \left( 2M-1\right)}\),為什么是\(\left( 2M-1\right)\times \left( 2M-1\right)\)?后面再說明這個問題!
WindowAttention層最主要的就是相對位置偏置的編碼部分比較復雜,其他操作都是我們熟悉的torch層,所以,這里仔細研究其處理過程。
相對位置偏置\(B\)是從\(\hat{B}\)的一個token。所以\(\hat{B}\)存儲了所有的偏置,\(B\)要通過索引獲取。下面是索引的生成:
coords:記錄了窗口的坐標,原點為窗口左上角;
coords_flatten:記錄了坐標的平鋪;sahpe為 ( 2, \(M^{2}\));
relative_coords:記錄了窗口內的像素(像素塊)的相對位置;如像素塊 \(patch_{a}\) 有\(M^{2}\)個相對位置,因為窗口內有\(M^{2}\)個像素塊。
作者在項目中的實現方式是:
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
因此,relative_coords的shape為 (\(M^{2}\),\(M^{2}\),2)。此時的relative_coords[0, :, :]標識的是(h, w)=(0, 0) 到所有點的相對坐標。注意此時窗口內的任意兩個坐標的相對位置我們都有了!
但是作者又進行了下面的操作:
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
前兩行做的事情是將相對左邊都移動到從0開始。第三行是將高H的相對坐標乘以了\(2M-1\)。
乘以\(2M-1\)是何用意?self.relative_position_bias_table是初始化的偏置表格,我們需要使用索引進行獲取,而relative_coords是生成索引的關鍵,生成索引的代碼為:
relative_position_index = relative_coords.sum(-1)
這里的索引本質上等同於偏置,如果索引相同則偏置也相同。首先我們探討relative_position_index應該有什么樣的性質:
-
當像素塊\(patch_{a}\)與像素塊\(patch_{a+1}\)同行(或者同列);像素塊\(patch_{b}\)與像素塊\(patch_{b+1}\)同行(或者同列)時。像素塊 \(patch_{a}\) 與像素塊 \(patch_{b}\) 之間的偏置應該和像素塊 \(patch_{a+1}\) 與像素塊 \(patch_{b+1}\) 之間的偏置一樣!即:
relative_position_index[i,j] = relative_position_index[M-1-j,M-1-i]注意上面的等式應該滿足絕大多數情況,但是在W的邊界上不應該滿足,因為我們的數據是行優先排列,索引+1即為下一個像素塊,如果當前是寬的邊界,那么下一個就換行了。因此我們的索引最好滿足\(j+1\)除以\(l\)余數不為0,其中\(l \in \left \{ M, 2M,\cdots ,M^{2} \right \}\)。
比較繁瑣的是寬的右邊界上,下一個換行了,但是這樣的模式下,偏置應為也一樣!如第\(3M-1\)像素塊相對於第0個像素塊的偏置 和 \(4M-1\)像素塊相對於第\(M\)個像素塊的偏置是一樣的!
-
那么問題來了,基於上面的准則,我們至少需要多少個偏置?由論文我們知道,需要\(\left (2M-1 \right )\times \left (2M-1 \right )\)個。這是怎么計算的呢?
首先,relative_position_index矩陣的shape為(\(M^{2}\),\(M^{2}\))。在主對角線方向上,我們共有\(2M^{2}-1\)條線,每條線都只有一種或兩種偏置索引,原因參考上面的規則說明。那到底哪些是1,哪些是2呢?推導可以發現,每條線組成偏置索引數量的序列為:
這個序列2的個數為:\(2\left(M-1\right) \times \left(M-1\right)\);1的數量為:\(2\left(M-1\right) + 2M - 1\)。
所以我們需要的索引數量為:
\[\begin{align} 2 \times 2&\left(M-1\right) \times \left(M-1\right) + 2\left(M-1\right) + 2M - 1\\ &= 4 \left( M^{2} - 2M + 1 \right) + 4M - 3 \\ &=4M^{2} - 8M + 4 + 4M - 3 \\ &=4M^{2} -4M + 1\\ &=\left( 2M - 1 \right)^{2} \end{align} \]實際上到這里我們大概就知道乘以\(2M-1\),就是為了讓索引滿足上述需求,且索引最小值到最大值是連續的!小於\(2M-1\)時,索引矩陣就不能滿足上面的規則;大於\(2M-1\)時,索引矩陣的值就不是連續的!那么為什么是\(2M-1\)?
解釋:
-
(1)在高H對應的特征圖上,每個\(M \times M\)的塊是一樣的,且,主對角線方向上是一樣的,這樣就會產生\(2M-1\)個不同\(M \times M\)的縱坐標H的索引塊;
-
(2)上訴高的每一個\(M \times M\)塊對應的寬的索引塊是一樣的;
-
(3)一個\(M \times M\)的寬索引塊,其寬的索引取值范圍是\(\left[ 0, 2M-2 \right]\);
-
(4)對於索引我們要使用:\(H \times x + W\)的形式獲取最終的相對位置索引,那么對於每一行我們乘以了\(x\),我們仍然需要保持其相鄰的\(M \times M\)塊之間的大小關系,對於高H,相鄰的高索引都是相差1,假設當前塊的行索引為\(m\),那么有:
\[mx+0> \left ( m-1 \right )x + 2M-2 \Rightarrow x > 2M-2 \] -
(5)對於relative_coords的左下角元素,合並高索引與寬索引后為:
\[\left ( 2M-2 \right )x + 2M-2 \leq \left ( 2M-1 \right )^{2} \Rightarrow x \leq 2M-1 \] -
(6)由不等式(6)、(7)可以得出乘的數只能是\(2M-1\)。
-
至此我們得到相對位置偏置的索引了,比如M=4,我們可以得到如下的索引:
tensor([[24, 23, 22, 21, 17, 16, 15, 14, 10, 9, 8, 7, 3, 2, 1, 0],
[25, 24, 23, 22, 18, 17, 16, 15, 11, 10, 9, 8, 4, 3, 2, 1],
[26, 25, 24, 23, 19, 18, 17, 16, 12, 11, 10, 9, 5, 4, 3, 2],
[27, 26, 25, 24, 20, 19, 18, 17, 13, 12, 11, 10, 6, 5, 4, 3],
[31, 30, 29, 28, 24, 23, 22, 21, 17, 16, 15, 14, 10, 9, 8, 7],
[32, 31, 30, 29, 25, 24, 23, 22, 18, 17, 16, 15, 11, 10, 9, 8],
[33, 32, 31, 30, 26, 25, 24, 23, 19, 18, 17, 16, 12, 11, 10, 9],
[34, 33, 32, 31, 27, 26, 25, 24, 20, 19, 18, 17, 13, 12, 11, 10],
[38, 37, 36, 35, 31, 30, 29, 28, 24, 23, 22, 21, 17, 16, 15, 14],
[39, 38, 37, 36, 32, 31, 30, 29, 25, 24, 23, 22, 18, 17, 16, 15],
[40, 39, 38, 37, 33, 32, 31, 30, 26, 25, 24, 23, 19, 18, 17, 16],
[41, 40, 39, 38, 34, 33, 32, 31, 27, 26, 25, 24, 20, 19, 18, 17],
[45, 44, 43, 42, 38, 37, 36, 35, 31, 30, 29, 28, 24, 23, 22, 21],
[46, 45, 44, 43, 39, 38, 37, 36, 32, 31, 30, 29, 25, 24, 23, 22],
[47, 46, 45, 44, 40, 39, 38, 37, 33, 32, 31, 30, 26, 25, 24, 23],
[48, 47, 46, 45, 41, 40, 39, 38, 34, 33, 32, 31, 27, 26, 25, 24]])
其他初始化
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
前向傳播:
-
首先獲取q, k, v。
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)注意這里是使用線性變換將維度擴大到3倍,使其與q, k, v對應。
生成q, k, v的代碼塊:
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2]注意這里的x的shape為 (num_windows*B, N, C) ,而上面的3可以理解為將新生成的通道分為q, k, v三份,再將每一份的通道數C拆為:self.num_heads 與 C // self.num_heads 兩個維度,以實現多頭機制。
-
計算注意力:\(QK^{T} / \sqrt{d}\)
q = q * self.scale attn = (q @ k.transpose(-2, -1)) -
給注意力添加偏置
relative_position_bias = self.relative_position_bias_table[ self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) -
實現多頭注意力
\[\mathbf{Attention} \left(Q,K,V \right) = \mathbf{SoftMax} \left(QK^{T} / \sqrt{d} + B \right)V \]
完!
