SwinUNet2022


1. 概述

本文提出了一種以\(Swin\)變壓器層為基本塊的\(SUNet\)恢復模型,並將其應用於\(UNet\)架構中進行圖像去噪。

2. 背景

圖像恢復是一種重要的低級圖像處理方法,可以提高其在目標檢測、圖像分割和圖像分類等高級視覺任務中的性能。在一般的恢復任務中,一個被損壞的圖像Y可以表示為:

\[Y=D(X)+n \tag 1 \]

其中\(X\)是一個干凈的圖像,\(D(\cdot)\)表示退化函數,\(n\)表示加性噪聲。一些常見的恢復任務是去噪、去模糊和去阻塞。

2.1 CNN局限性

雖然大多數基於卷積神經網絡(CNN)的方法都取得了良好的性能,但卷積層存在幾個問題。首先,卷積核與圖像的內容無關(無法與圖像內容相適應)。使用相同的卷積核來恢復不同的圖像區域可能不是最好的解決方案。其次,由於卷積核可以看作是一個小塊,其中獲取的特征是局部信息,換句話說,當我們進行長期依賴建模時,全局信息就會丟失。

3. 結構

3.1 UNet

目前,UNet由於具有層次特征映射來獲得豐富的多尺度上下文特征,是許多圖像處理應用中著名的架構。此外,它利用編碼器和解碼器之間的跳躍連接來增強圖像的重建過程。UNet被廣泛應用於許多計算機視覺任務,如分割、恢復[。此外,它還有各種改進的版本,如Res-UNet,Dense-UNet,Attention-UNet[和Non-local-UNet。由於具有較強的自適應骨干網,UNet可以很容易地應用於不同的提取塊,以提高性能。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-MpLyfXEs-1647427670075)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316140830395.png)]

3.2 Swin Transformer

Transformer模型在自然語言處理(NLP)領域取得了成功,並具有良好的競爭性能,特別是在圖像分類方面。然而,直接使用Transformer到視覺任務的兩個主要問題是:

(1)圖像和序列之間的尺度差異很大。由於Transformer需要參數量為一維序列參數的平方倍,所以存在長序列建模的缺陷。

(2)Transformer不擅長解決實例分割等密集預測任務,即像素級任務。然而,Swin Transfomer通過滑動窗口解決了上述問題,降低了參數,並在許多像素級視覺任務中實現了最先進的性能。

3.3 SUNet

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-mK2P8qBG-1647427670077)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316141320770.png)]

所提出的Swin Transformer UNet(SUNet)的架構是基於圖像分割模型,如上圖所示。SUNet由三個模塊組成:

(1)淺層特征提取;

(2)UNet特征提取;

(3)重建模塊

淺層特征提取模塊:

對於有噪聲的輸入圖像\(Y∈R^{H×W×3}\),其中H,W為失真圖像的分辨率。我們使用單個3×3卷積層\(M_{SFE}(\cdot)\)獲取輸入圖像的顏色或紋理等低頻信息。淺特征\(F_{shallow}∈R^{H×W×C}\)可以表示為:

\[F_{shallow}=M_{SFE}(Y) \tag 2 \]

其中,C是淺層特征的通道數,在后一個實驗部分中,我們都設置為96.

UNet 特征提取網絡:

然后,將淺層特征\(F_{shallow}\)輸入UNet特征提取\(M_{UFE}(\cdot)\),UNet用來提取高級、多尺度深度特征\(F_{deep}∈R^{H×W×C}\)

\[F_{deep}=M_{UFE}(F_{shallow}) \tag 3 \]

其中,\(M_{UFE}(\cdot)\)是帶有Swin變壓器塊的UNet架構,它在單個塊中包含8個Swin Transformer層來代替卷積。Swin Transformer Block(STB)和Swin Transformer Layer(STL)將在下一小節中進行詳細說明。

重建層:

最后,我們仍然使用3×3卷積\(M_{R}(\cdot)\)從深度特征\(F_{deep}\)中生成無噪聲圖像\(\hat{X}∈R^{H×W×3}\),其公式為:

\[\hat{X}=M_{R}(F_{deep}) \tag 4 \]

注意,\(\hat{X}\)是以噪聲圖像\(Y\)作為SUNet的輸入得到的,其中\(X\)是(1)中Y圖像的原高分率圖像。

3.4 Loss function

我們優化了我們的SUNet端到端與規則的\(L1\)像素損失的圖像去噪:

\[L_{denoise}=||\hat{X}-X||_1 \tag 5 \]

3.5 Swin Transformer Block

在UNet提取模塊中,我們使用STB來代替傳統的卷積層,如下圖所示。STL是基於NLP中的原始Transformer Layer。STL的數量總是2的倍數,其中一個是window multi-head-self-attention(W-MSA),另一個是shifted-window multi-head self-attention(SW-MSA)。

STL的公式描述:

\[\hat{f}^L=W-MSA(LN(f^{L-1}))+f^{L-1} \\ f^L=MLP(LN(\hat{f}^L))+\hat{f}^L \\ \hat{f}^{L+1}=SW-MSA(LN(f^{L}))+f^{L} \\ f^{L+1}=MLP(LN(\hat{f}^{L+1}))+\hat{f}^{L+1} \tag 6 \]

其中,\(LN(\cdot)\)表示為層歸一化,\(MLP\)是多層感知器,它具有兩個完全連接的層,同時后面跟一個線性單位(GELU)激活函數。

3.6 Resizing module

由於UNet具有不同的特征圖尺度,因此調整大小的模塊(例如,下樣本和上樣本)是必要的。在我們的SUNet中,我們使用\(patch\ merging\),並提出\(dual\ up-sample\)分別作為下樣本和上樣本模塊。

3.6.1 patch merging

對於降采樣模塊,該文將每一組2×2相鄰斑塊的輸入特征連接起來,然后使用線性層獲得指定的輸出通道特征。我們也可以把這看作是做卷積操作的第一步,也就是展開輸入的特征映射。

3.6.2 Dual up-sample

對於上采樣,原始的Swin-UNet采用patch expanding方法,等價於上采樣模塊中的轉置卷積。然而,轉置卷積很容易面對塊效應。在這里,我們提出了一個新的模塊,稱為雙上樣本,它包括兩種現有的上樣本方法(即Bilinear和PixelShuffle),以防止棋盤式的artifacts。所提出的上采樣模塊的體系結構如下圖所示。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-joxathTz-1647427670077)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316145653720.png)]

4. 結果

評估指標:

為了進行定量比較,我們考慮了峰值信噪比(PSNR)和結構相似度(SSIM)指數度量。

訓練集:

采用DIV2K作為訓練集,一共有900張高清圖片。我們對每個訓練圖像隨機裁剪100個大小為\(256×256\)的斑塊,並對\(800\)張訓練圖像從\(σ=5\)\(σ=50\)\(patch\)中隨機添加AWGN噪聲。至於驗證集,我們直接使用包含100張圖像的測試集,並添加具有三種不同噪聲水平的AWGN,\(σ=10、σ=30和σ=50\)

測試集:

對於評估,我們選擇了CBSD68數據集,它有68張彩色圖像,分辨率為768×512,以及Kodak24張數據集,由24張圖像組成,圖像大小為321×481。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-WdJUyCFm-1647427670078)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316170249843.png)]

在表1中,我們對去噪圖像進行了客觀的質量評價,並觀察到以下三件事:

(1)該文的SUNet具有競爭性的SSIM值,因為Swin-Transformer是基於全局信息(q,k,v可以提取全局信息),使得去噪圖像擁有更多的視覺效果。

(2)與基於unet的方法(DHDN、RDUNet)相比,該文所提出的SUNet模型中參數(↓60%)和FLOPs(↓3%)較少,在PSNR和SSIM上仍保持良好的得分

(3)與基於cnn的方法(DnCNN,IrCNN,FFDNet)相比,該文得到了其中最好的PSNR和SSIM結果,以及幾乎相同的FLOPs。雖然該文的模型的參數最多(99M),但它是由於自注意操作不能共享核的權值造成的。

4. 總結

  • 提出了一種基於圖像分割的雙unet模型的雙變換網絡進行圖像去噪。
  • 該文提出了一種雙上樣本塊結構,它包括亞像素方法和雙線性上樣本方法,以防止棋盤偽影。實驗結果表明,該方法優於轉置卷積的原始上樣本。
  • 該文的模型是第一個結合Swin變壓器和UNet進行去噪的模型。

Reference: Swin-unet: Unet-like pure transformer for medical image segmentation

5. 某些代碼的理解

5.1 window attention中的相對位置編碼

代碼位置位於: ./model/SUNet_detail.py 中的89行左右

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-KhTSbrYb-1647503570456)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316205927221.png)]

我這里展示了一個例子:

>>> import torch
>>> coords_h=torch.arange(3)
>>> coords_w=torch.arange(3)
>>> coords=torch.stack(torch.meshgrid([coords_h,coords_w]))
>>> coords
tensor([[[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2]],

        [[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]]])
>>> coords_flatten=torch.flatten(coords,1)
>>> coords_flatten
tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],
        [0, 1, 2, 0, 1, 2, 0, 1, 2]])
>>> relative_coords=coords_flatten[:,:,None]-coords_flatten[:,None,:]
>>> relative_coords
tensor([[[ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0]],

        [[ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0],
         [ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0],
         [ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0]]])

可以看到\(relative\_coords\)的第一維是2,分別對應x軸和y軸方向(或者高,寬的方向)。剩下兩維呢是9*9。SUNet中間使用了一些SwinIR的結構,在SwinIR中是有shift-window的,在這里,我設置的window size為3。又因為再做attention的時候,我們把每一個window中的像素點當作一個token,那么最終的attention map(\(q * v\))的最后兩維就是\(window\_width \cdot window\_height\)

下面再來看具體的物理意義,\(3 \times 3\)的window一共有9個數值,第一維度分別代表這兩個軸;第一個矩陣中,第一行分別代表着第一個數值(一共有9個)在某個軸上相對於其他位置的距離(在第一個數值的右邊為負,左邊為正),第二個矩陣類似。不同window的相對位置是一樣的。

>>> relative_coords
tensor([[[ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0]],

        [[ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0],
         [ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0],
         [ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0]]])
>>> relative_coords = relative_coords.permute(1, 2, 0).contiguous()
>>> relative_coords
tensor([[[ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [-1,  0],
         [-1, -1],
         [-1, -2],
         [-2,  0],
         [-2, -1],
         [-2, -2]],

        [[ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [-1,  1],
         [-1,  0],
         [-1, -1],
         [-2,  1],
         [-2,  0],
         [-2, -1]],

        [[ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [-1,  2],
         [-1,  1],
         [-1,  0],
         [-2,  2],
         [-2,  1],
         [-2,  0]],

        [[ 1,  0],
         [ 1, -1],
         [ 1, -2],
         [ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [-1,  0],
         [-1, -1],
         [-1, -2]],

        [[ 1,  1],
         [ 1,  0],
         [ 1, -1],
         [ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [-1,  1],
         [-1,  0],
         [-1, -1]],

        [[ 1,  2],
         [ 1,  1],
         [ 1,  0],
         [ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [-1,  2],
         [-1,  1],
         [-1,  0]],

        [[ 2,  0],
         [ 2, -1],
         [ 2, -2],
         [ 1,  0],
         [ 1, -1],
         [ 1, -2],
         [ 0,  0],
         [ 0, -1],
         [ 0, -2]],

        [[ 2,  1],
         [ 2,  0],
         [ 2, -1],
         [ 1,  1],
         [ 1,  0],
         [ 1, -1],
         [ 0,  1],
         [ 0,  0],
         [ 0, -1]],

        [[ 2,  2],
         [ 2,  1],
         [ 2,  0],
         [ 1,  2],
         [ 1,  1],
         [ 1,  0],
         [ 0,  2],
         [ 0,  1],
         [ 0,  0]]])

下面的代碼就是將相對位置坐標全部加上\(window\_size-1\),使得全部為正值:

>>> relative_coords[:, :, 0] += window_size[0] - 1
>>> relative_coords[:, :, 1] += window_size[1] - 1
>>> relative_coords
tensor([[[2, 2],
         [2, 1],
         [2, 0],
         [1, 2],
         [1, 1],
         [1, 0],
         [0, 2],
         [0, 1],
         [0, 0]],

        [[2, 3],
         [2, 2],
         [2, 1],
         [1, 3],
         [1, 2],
         [1, 1],
         [0, 3],
         [0, 2],
         [0, 1]],

        [[2, 4],
         [2, 3],
         [2, 2],
         [1, 4],
         [1, 3],
         [1, 2],
         [0, 4],
         [0, 3],
         [0, 2]],

        [[3, 2],
         [3, 1],
         [3, 0],
         [2, 2],
         [2, 1],
         [2, 0],
         [1, 2],
         [1, 1],
         [1, 0]],

        [[3, 3],
         [3, 2],
         [3, 1],
         [2, 3],
         [2, 2],
         [2, 1],
         [1, 3],
         [1, 2],
         [1, 1]],

        [[3, 4],
         [3, 3],
         [3, 2],
         [2, 4],
         [2, 3],
         [2, 2],
         [1, 4],
         [1, 3],
         [1, 2]],

        [[4, 2],
         [4, 1],
         [4, 0],
         [3, 2],
         [3, 1],
         [3, 0],
         [2, 2],
         [2, 1],
         [2, 0]],

        [[4, 3],
         [4, 2],
         [4, 1],
         [3, 3],
         [3, 2],
         [3, 1],
         [2, 3],
         [2, 2],
         [2, 1]],

        [[4, 4],
         [4, 3],
         [4, 2],
         [3, 4],
         [3, 3],
         [3, 2],
         [2, 4],
         [2, 3],
         [2, 2]]])

下面是將橫縱坐標的相對位置加起來:

>>> relative_position_index = relative_coords.sum(-1)
>>> relative_position_index
tensor([[4, 3, 2, 3, 2, 1, 2, 1, 0],
        [5, 4, 3, 4, 3, 2, 3, 2, 1],
        [6, 5, 4, 5, 4, 3, 4, 3, 2],
        [5, 4, 3, 4, 3, 2, 3, 2, 1],
        [6, 5, 4, 5, 4, 3, 4, 3, 2],
        [7, 6, 5, 6, 5, 4, 5, 4, 3],
        [6, 5, 4, 5, 4, 3, 4, 3, 2],
        [7, 6, 5, 6, 5, 4, 5, 4, 3],
        [8, 7, 6, 7, 6, 5, 6, 5, 4]])

下面為定義bias,隨機初始化,但是在網絡的迭代訓練中,是會被反向傳播的

 # define a parameter table of relative position bias
 self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] 			- 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
 trunc_normal_(self.relative_position_bias_table, std=.02)

一個window里面明明只有9個數值,為什么定義bias時,矩陣的維度為\(25, num_heads\)呢?**

這是因為上面我們加了\(window\_size -1\):不加之前最大值為\(window\_size-1\),后面在加上\(window\_size-1\),此時,最大值為\(2\times window\_size -2\),再算上零,一共有\(2\times window\_size-1\)。所以再初始化bias的時候,我覺得維度為\(2\times window\_size-1\)

就夠了,不知道為什么要定義\((2\times window\_size-1)\times 2\times window\_size-1\)呢?

每個window是獨立attention的,所以每個window的relative_position_bias都是一樣的。

下面就是加MASK操作了,不再贅述。

5.2 Shifted-window-attention

代碼位於\(SwinTransformerBlock\)中。

\(shift\_size>0\)時:

# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
# nW, window_size, window_size, 1
mask_windows = window_partition(img_mask, self.window_size)  
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, 			 	float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
>>> img_mask = torch.zeros((1, 3, 3, 1))
>>> h_slices=(slice(0,-3),slice(-3,1),slice(-1,None))
>>> h_slices
(slice(0, -3, None), slice(-3, 1, None), slice(-1, None, None))
>>> w_slices=(slice(0,-3),slice(-3,-1),slice(-1,None))
>>> w_slices
(slice(0, -3, None), slice(-3, -1, None), slice(-1, None, None))
>>> cnt = 0
>>> for h in h_slices:
...     for w in w_slices:
...             img_mask[:,h,w,:]=cnt
...             cnt+=1
...
>>> img_mask
tensor([[[[4.],
          [4.],
          [5.]],

         [[4.],
          [4.],
          [5.]],

         [[7.],
          [7.],
          [8.]]]])
# nW, window_size, window_size, 1    [1,3,3,1]
mask_windows = window_partition(img_mask, self.window_size)
# 經過window_partition后,沒有發生變化
>>> mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
tensor([[4., 4., 5., 4., 4., 5., 7., 7., 8.]])
>>> attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
>>> attn_mask.size()
torch.Size([1, 9, 9])
>>> attn_mask
tensor([[[ 0.,  0.,  1.,  0.,  0.,  1.,  3.,  3.,  4.],
         [ 0.,  0.,  1.,  0.,  0.,  1.,  3.,  3.,  4.],
         [-1., -1.,  0., -1., -1.,  0.,  2.,  2.,  3.],
         [ 0.,  0.,  1.,  0.,  0.,  1.,  3.,  3.,  4.],
         [ 0.,  0.,  1.,  0.,  0.,  1.,  3.,  3.,  4.],
         [-1., -1.,  0., -1., -1.,  0.,  2.,  2.,  3.],
         [-3., -3., -2., -3., -3., -2.,  0.,  0.,  1.],
         [-3., -3., -2., -3., -3., -2.,  0.,  0.,  1.],
         [-4., -4., -3., -4., -4., -3., -1., -1.,  0.]]])
>>> attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
>>> attn_mask
tensor([[[   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [-100., -100.,    0., -100., -100.,    0., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [-100., -100.,    0., -100., -100.,    0., -100., -100., -100.],
         [-100., -100., -100., -100., -100., -100.,    0.,    0., -100.],
         [-100., -100., -100., -100., -100., -100.,    0.,    0., -100.],
         [-100., -100., -100., -100., -100., -100., -100., -100.,    0.]]])

我們再來看forward函數對x的shift操作:

>>> x=torch.arange(0,9)
>>> x
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
>>> x=x.unsqueeze(0)+x.unsqueeze(1)
>>> x
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9],
        [ 2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 3,  4,  5,  6,  7,  8,  9, 10, 11],
        [ 4,  5,  6,  7,  8,  9, 10, 11, 12],
        [ 5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 6,  7,  8,  9, 10, 11, 12, 13, 14],
        [ 7,  8,  9, 10, 11, 12, 13, 14, 15],
        [ 8,  9, 10, 11, 12, 13, 14, 15, 16]])
>>> x=x.unsqueeze(2).unsqueeze(0)
>>> x.size()
torch.Size([1, 9, 9, 1])
>>> shifted_x = torch.roll(x, shifts=(-1, -1), dims=(1, 2))
>>> xx=shifted_x.squeeze(3).squeeze(0)
>>> xx
tensor([[ 2,  3,  4,  5,  6,  7,  8,  9,  1],
        [ 3,  4,  5,  6,  7,  8,  9, 10,  2],
        [ 4,  5,  6,  7,  8,  9, 10, 11,  3],
        [ 5,  6,  7,  8,  9, 10, 11, 12,  4],
        [ 6,  7,  8,  9, 10, 11, 12, 13,  5],
        [ 7,  8,  9, 10, 11, 12, 13, 14,  6],
        [ 8,  9, 10, 11, 12, 13, 14, 15,  7],
        [ 9, 10, 11, 12, 13, 14, 15, 16,  8],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  0]])

那為什么要加mask呢,它是由一個假設的,假設各個window之間不相關,各個window單獨做attention。其中我們以\(H=9,W=9, window\_size=3, shift\_size=1\)為例。

圖片參考:SWin Transformer

沒有進行shift的window划分圖:
在這里插入圖片描述

再forward中,是會對輸入的x進行shift操作的:
在這里插入圖片描述

shift后的操作:

在這里插入圖片描述

其中,上圖黑線代表原來的邊界。每個彩色框代表經過shift操作后的window划分,可以看到每個彩色框內部黑線位置是一樣的;黑線是window的邊界。

>>> attn_mask
tensor([[[   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [-100., -100.,    0., -100., -100.,    0., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [-100., -100.,    0., -100., -100.,    0., -100., -100., -100.],
         [-100., -100., -100., -100., -100., -100.,    0.,    0., -100.],
         [-100., -100., -100., -100., -100., -100.,    0.,    0., -100.],
         [-100., -100., -100., -100., -100., -100., -100., -100.,    0.]]])

我們舉個例子說明:

以第一個彩色框為例,第一行代表第一個元素是否可以看到對應位置的元素(0代表看得到,-100代表看不到)。當兩個像素點位於不同的window之中(黑線)就是看不到,就賦給一個負數,后面再做softmax,對應權重就會非常小。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-COD9Kl1R-1647503570459)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220317132729985.png)]

5.3 Dual Up Sample

使用PixelShuffle和bilinear合起來的特征作為輸出。

5.4 Absolute position embedding

# absolute position embedding
if self.ape:
    self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, 			                embed_dim)) 
    trunc_normal_(self.absolute_pos_embed, std=.02)

大概位置再SUNet的635行左右。

那為什么要加Absolute position embedding呢?

是因為再5.2 window attention中呢,是有一個relative position的,但是relative position作用范圍僅僅是在一個window里面,即每個window相同位置上的relative position都是一樣的。所以需要absolute position。

5.5 DownSampling

下采樣是通過Patch Embedding實現的。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM