論文鏈接:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
Introduction
目前Transformer應用到圖像領域主要有兩大挑戰:
- 視覺實體變化大,在不同場景下視覺Transformer性能未必很好
- 圖像分辨率高,像素點多,Transformer基於全局自注意力的計算導致計算量較大
提出了一種包含滑窗操作,具有層級設計的Swin Transformer
總體架構
注意W-MSA和SW-MSA是成對使用的

整個模型采取層次化的設計,一共包含4個Stage,每個stage都會縮小輸入特征圖的分辨率,像CNN一樣逐層擴大感受野。
W-MSA
Windows Multi-Head Self-Attention
每個方框都是一個窗口,每個窗口是固定有7×7個patch,但是patch的大小是不固定的,它會隨着patch merging的操作而發生變化。比如一開始patch大小是4x4,把周邊四個窗口的patch拼在一起,從而得到了8x8的patch。

經過這一系列的操作之后,patch的數目在變少,最后整張圖只有一個窗口,7個patch。所以我們可以認為降采樣是指讓patch的數量減少,但是patch的大小在變大。

CNN在每個窗口做的是卷積的計算,每個窗口最后得到一個值,這個值代表着這個窗口的特征。而swin transformer在每個窗口做的是self-attention的計算,得到的是一個更新過的窗口,然后通過patch merging的操作,把窗口做了個合並,再繼續對這個合並后的窗口做self-attention的計算。
每個窗口內計算self-attention可以減小計算量,但是缺點是窗口之間無法進行信息交互,也就是說每個窗口的感受野變小,所以文章提出了shift window attention
SW-MSA
Shifted Windows Multi-Head Self-Attention
W-MSA和SW-MSA是成對使用的,那么第L+1層使用的就是SW-MSA(右側圖)。根據左右兩幅圖對比能夠發現窗口(Windows)發生了偏移(可以理解成窗口從左上角分別向右側和下方各偏移了\(\left \lfloor \frac {M} {2} \right \rfloor\)個像素)。
偏移后的窗口中,比如對於第一行第2列的2x4的窗口,它能夠使上一層的第一排的兩個窗口信息進行交流。再比如,第二行第二列的4x4的窗口,他能夠使上一層的四個窗口信息進行交流


但對窗口進行偏移之后,窗口的數量又增多了(從4個變成9個),這樣計算量又大了。
接下來來到本文的最亮點,通過設置合理的mask,讓Shifted Window Attention在與Window Attention相同的窗口個數下,達到等價的計算結果。
首先我們對Shift Window后的每個窗口都給上index,並且做一個roll操作(window_size=2, shift_size=-1)

但是把不同的區域合並在一起(比如5和3)進行MSA,這信息不就亂竄了嗎?
是的,為了防止這個問題,在實際計算中使用的是masked MSA即帶蒙板mask的MSA,這樣就能夠通過設置蒙板來隔絕不同區域的信息了,這個mask的計算方法是將矩陣乘積后index不一致的地方暴力減去100,softmax后就會忽略掉對應的值。
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)

Relative Position Bias
B就是bias,相對位置索引怎么求的見博客
總結
- 先對特征圖進行LayerNorm
- 通過self.shift_size決定是否需要對特征圖進行shift
- 然后將特征圖切成一個個窗口
- 計算Attention,通過self.attn_mask來區分Window Attention還是Shift Window Attention
- 將各個窗口做merging
- 如果之前有做shift操作,此時進行reverse shift,把之前的shift操作恢復
- 做dropout和殘差連接
- 再通過一層LayerNorm+全連接層,以及dropout和殘差連接