1. 概述
本文提出了一種以\(Swin\)變壓器層為基本塊的\(SUNet\)恢復模型,並將其應用於\(UNet\)架構中進行圖像去噪。
2. 背景
圖像恢復是一種重要的低級圖像處理方法,可以提高其在目標檢測、圖像分割和圖像分類等高級視覺任務中的性能。在一般的恢復任務中,一個被損壞的圖像Y可以表示為:
其中\(X\)是一個干凈的圖像,\(D(\cdot)\)表示退化函數,\(n\)表示加性噪聲。一些常見的恢復任務是去噪、去模糊和去阻塞。
2.1 CNN局限性
雖然大多數基於卷積神經網絡(CNN)的方法都取得了良好的性能,但卷積層存在幾個問題。首先,卷積核與圖像的內容無關(無法與圖像內容相適應)。使用相同的卷積核來恢復不同的圖像區域可能不是最好的解決方案。其次,由於卷積核可以看作是一個小塊,其中獲取的特征是局部信息,換句話說,當我們進行長期依賴建模時,全局信息就會丟失。
3. 結構
3.1 UNet
目前,UNet由於具有層次特征映射來獲得豐富的多尺度上下文特征,是許多圖像處理應用中著名的架構。此外,它利用編碼器和解碼器之間的跳躍連接來增強圖像的重建過程。UNet被廣泛應用於許多計算機視覺任務,如分割、恢復[。此外,它還有各種改進的版本,如Res-UNet,Dense-UNet,Attention-UNet[和Non-local-UNet。由於具有較強的自適應骨干網,UNet可以很容易地應用於不同的提取塊,以提高性能。
3.2 Swin Transformer
Transformer模型在自然語言處理(NLP)領域取得了成功,並具有良好的競爭性能,特別是在圖像分類方面。然而,直接使用Transformer到視覺任務的兩個主要問題是:
(1)圖像和序列之間的尺度差異很大。由於Transformer需要參數量為一維序列參數的平方倍,所以存在長序列建模的缺陷。
(2)Transformer不擅長解決實例分割等密集預測任務,即像素級任務。然而,Swin Transfomer通過滑動窗口解決了上述問題,降低了參數,並在許多像素級視覺任務中實現了最先進的性能。
3.3 SUNet
所提出的Swin Transformer UNet(SUNet)的架構是基於圖像分割模型,如上圖所示。SUNet由三個模塊組成:
(1)淺層特征提取;
(2)UNet特征提取;
(3)重建模塊
淺層特征提取模塊:
對於有噪聲的輸入圖像\(Y∈R^{H×W×3}\),其中H,W為失真圖像的分辨率。我們使用單個3×3卷積層\(M_{SFE}(\cdot)\)獲取輸入圖像的顏色或紋理等低頻信息。淺特征\(F_{shallow}∈R^{H×W×C}\)可以表示為:
其中,C是淺層特征的通道數,在后一個實驗部分中,我們都設置為96.
UNet 特征提取網絡:
然后,將淺層特征\(F_{shallow}\)輸入UNet特征提取\(M_{UFE}(\cdot)\),UNet用來提取高級、多尺度深度特征\(F_{deep}∈R^{H×W×C}\):
其中,\(M_{UFE}(\cdot)\)是帶有Swin變壓器塊的UNet架構,它在單個塊中包含8個Swin Transformer層來代替卷積。Swin Transformer Block(STB)和Swin Transformer Layer(STL)將在下一小節中進行詳細說明。
重建層:
最后,我們仍然使用3×3卷積\(M_{R}(\cdot)\)從深度特征\(F_{deep}\)中生成無噪聲圖像\(\hat{X}∈R^{H×W×3}\),其公式為:
注意,\(\hat{X}\)是以噪聲圖像\(Y\)作為SUNet的輸入得到的,其中\(X\)是(1)中Y圖像的原高分率圖像。
3.4 Loss function
我們優化了我們的SUNet端到端與規則的\(L1\)像素損失的圖像去噪:
3.5 Swin Transformer Block
在UNet提取模塊中,我們使用STB來代替傳統的卷積層,如下圖所示。STL是基於NLP中的原始Transformer Layer。STL的數量總是2的倍數,其中一個是window multi-head-self-attention(W-MSA),另一個是shifted-window multi-head self-attention(SW-MSA)。
STL的公式描述:
其中,\(LN(\cdot)\)表示為層歸一化,\(MLP\)是多層感知器,它具有兩個完全連接的層,同時后面跟一個線性單位(GELU)激活函數。
3.6 Resizing module
由於UNet具有不同的特征圖尺度,因此調整大小的模塊(例如,下樣本和上樣本)是必要的。在我們的SUNet中,我們使用\(patch\ merging\),並提出\(dual\ up-sample\)分別作為下樣本和上樣本模塊。
3.6.1 patch merging
對於降采樣模塊,該文將每一組2×2相鄰斑塊的輸入特征連接起來,然后使用線性層獲得指定的輸出通道特征。我們也可以把這看作是做卷積操作的第一步,也就是展開輸入的特征映射。
3.6.2 Dual up-sample
對於上采樣,原始的Swin-UNet采用patch expanding方法,等價於上采樣模塊中的轉置卷積。然而,轉置卷積很容易面對塊效應。在這里,我們提出了一個新的模塊,稱為雙上樣本,它包括兩種現有的上樣本方法(即Bilinear和PixelShuffle),以防止棋盤式的artifacts。所提出的上采樣模塊的體系結構如下圖所示。
4. 結果
評估指標:
為了進行定量比較,我們考慮了峰值信噪比(PSNR)和結構相似度(SSIM)指數度量。
訓練集:
采用DIV2K作為訓練集,一共有900張高清圖片。我們對每個訓練圖像隨機裁剪100個大小為\(256×256\)的斑塊,並對\(800\)張訓練圖像從\(σ=5\)到\(σ=50\)的\(patch\)中隨機添加AWGN噪聲。至於驗證集,我們直接使用包含100張圖像的測試集,並添加具有三種不同噪聲水平的AWGN,\(σ=10、σ=30和σ=50\)。
測試集:
對於評估,我們選擇了CBSD68數據集,它有68張彩色圖像,分辨率為768×512,以及Kodak24張數據集,由24張圖像組成,圖像大小為321×481。
在表1中,我們對去噪圖像進行了客觀的質量評價,並觀察到以下三件事:
(1)該文的SUNet具有競爭性的SSIM值,因為Swin-Transformer是基於全局信息(q,k,v可以提取全局信息),使得去噪圖像擁有更多的視覺效果。
(2)與基於unet的方法(DHDN、RDUNet)相比,該文所提出的SUNet模型中參數(↓60%)和FLOPs(↓3%)較少,在PSNR和SSIM上仍保持良好的得分
(3)與基於cnn的方法(DnCNN,IrCNN,FFDNet)相比,該文得到了其中最好的PSNR和SSIM結果,以及幾乎相同的FLOPs。雖然該文的模型的參數最多(99M),但它是由於自注意操作不能共享核的權值造成的。
4. 總結
- 提出了一種基於圖像分割的雙unet模型的雙變換網絡進行圖像去噪。
- 該文提出了一種雙上樣本塊結構,它包括亞像素方法和雙線性上樣本方法,以防止棋盤偽影。實驗結果表明,該方法優於轉置卷積的原始上樣本。
- 該文的模型是第一個結合Swin變壓器和UNet進行去噪的模型。
Reference: Swin-unet: Unet-like pure transformer for medical image segmentation
5. 某些代碼的理解
5.1 window attention中的相對位置編碼
代碼位置位於: ./model/SUNet_detail.py 中的89行左右
我這里展示了一個例子:
>>> import torch
>>> coords_h=torch.arange(3)
>>> coords_w=torch.arange(3)
>>> coords=torch.stack(torch.meshgrid([coords_h,coords_w]))
>>> coords
tensor([[[0, 0, 0],
[1, 1, 1],
[2, 2, 2]],
[[0, 1, 2],
[0, 1, 2],
[0, 1, 2]]])
>>> coords_flatten=torch.flatten(coords,1)
>>> coords_flatten
tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],
[0, 1, 2, 0, 1, 2, 0, 1, 2]])
>>> relative_coords=coords_flatten[:,:,None]-coords_flatten[:,None,:]
>>> relative_coords
tensor([[[ 0, 0, 0, -1, -1, -1, -2, -2, -2],
[ 0, 0, 0, -1, -1, -1, -2, -2, -2],
[ 0, 0, 0, -1, -1, -1, -2, -2, -2],
[ 1, 1, 1, 0, 0, 0, -1, -1, -1],
[ 1, 1, 1, 0, 0, 0, -1, -1, -1],
[ 1, 1, 1, 0, 0, 0, -1, -1, -1],
[ 2, 2, 2, 1, 1, 1, 0, 0, 0],
[ 2, 2, 2, 1, 1, 1, 0, 0, 0],
[ 2, 2, 2, 1, 1, 1, 0, 0, 0]],
[[ 0, -1, -2, 0, -1, -2, 0, -1, -2],
[ 1, 0, -1, 1, 0, -1, 1, 0, -1],
[ 2, 1, 0, 2, 1, 0, 2, 1, 0],
[ 0, -1, -2, 0, -1, -2, 0, -1, -2],
[ 1, 0, -1, 1, 0, -1, 1, 0, -1],
[ 2, 1, 0, 2, 1, 0, 2, 1, 0],
[ 0, -1, -2, 0, -1, -2, 0, -1, -2],
[ 1, 0, -1, 1, 0, -1, 1, 0, -1],
[ 2, 1, 0, 2, 1, 0, 2, 1, 0]]])
可以看到\(relative\_coords\)的第一維是2,分別對應x軸和y軸方向(或者高,寬的方向)。剩下兩維呢是9*9。SUNet中間使用了一些SwinIR的結構,在SwinIR中是有shift-window的,在這里,我設置的window size為3。又因為再做attention的時候,我們把每一個window中的像素點當作一個token,那么最終的attention map(\(q * v\))的最后兩維就是\(window\_width \cdot window\_height\)。
下面再來看具體的物理意義,\(3 \times 3\)的window一共有9個數值,第一維度分別代表這兩個軸;第一個矩陣中,第一行分別代表着第一個數值(一共有9個)在某個軸上相對於其他位置的距離(在第一個數值的右邊為負,左邊為正),第二個矩陣類似。不同window的相對位置是一樣的。
>>> relative_coords
tensor([[[ 0, 0, 0, -1, -1, -1, -2, -2, -2],
[ 0, 0, 0, -1, -1, -1, -2, -2, -2],
[ 0, 0, 0, -1, -1, -1, -2, -2, -2],
[ 1, 1, 1, 0, 0, 0, -1, -1, -1],
[ 1, 1, 1, 0, 0, 0, -1, -1, -1],
[ 1, 1, 1, 0, 0, 0, -1, -1, -1],
[ 2, 2, 2, 1, 1, 1, 0, 0, 0],
[ 2, 2, 2, 1, 1, 1, 0, 0, 0],
[ 2, 2, 2, 1, 1, 1, 0, 0, 0]],
[[ 0, -1, -2, 0, -1, -2, 0, -1, -2],
[ 1, 0, -1, 1, 0, -1, 1, 0, -1],
[ 2, 1, 0, 2, 1, 0, 2, 1, 0],
[ 0, -1, -2, 0, -1, -2, 0, -1, -2],
[ 1, 0, -1, 1, 0, -1, 1, 0, -1],
[ 2, 1, 0, 2, 1, 0, 2, 1, 0],
[ 0, -1, -2, 0, -1, -2, 0, -1, -2],
[ 1, 0, -1, 1, 0, -1, 1, 0, -1],
[ 2, 1, 0, 2, 1, 0, 2, 1, 0]]])
>>> relative_coords = relative_coords.permute(1, 2, 0).contiguous()
>>> relative_coords
tensor([[[ 0, 0],
[ 0, -1],
[ 0, -2],
[-1, 0],
[-1, -1],
[-1, -2],
[-2, 0],
[-2, -1],
[-2, -2]],
[[ 0, 1],
[ 0, 0],
[ 0, -1],
[-1, 1],
[-1, 0],
[-1, -1],
[-2, 1],
[-2, 0],
[-2, -1]],
[[ 0, 2],
[ 0, 1],
[ 0, 0],
[-1, 2],
[-1, 1],
[-1, 0],
[-2, 2],
[-2, 1],
[-2, 0]],
[[ 1, 0],
[ 1, -1],
[ 1, -2],
[ 0, 0],
[ 0, -1],
[ 0, -2],
[-1, 0],
[-1, -1],
[-1, -2]],
[[ 1, 1],
[ 1, 0],
[ 1, -1],
[ 0, 1],
[ 0, 0],
[ 0, -1],
[-1, 1],
[-1, 0],
[-1, -1]],
[[ 1, 2],
[ 1, 1],
[ 1, 0],
[ 0, 2],
[ 0, 1],
[ 0, 0],
[-1, 2],
[-1, 1],
[-1, 0]],
[[ 2, 0],
[ 2, -1],
[ 2, -2],
[ 1, 0],
[ 1, -1],
[ 1, -2],
[ 0, 0],
[ 0, -1],
[ 0, -2]],
[[ 2, 1],
[ 2, 0],
[ 2, -1],
[ 1, 1],
[ 1, 0],
[ 1, -1],
[ 0, 1],
[ 0, 0],
[ 0, -1]],
[[ 2, 2],
[ 2, 1],
[ 2, 0],
[ 1, 2],
[ 1, 1],
[ 1, 0],
[ 0, 2],
[ 0, 1],
[ 0, 0]]])
下面的代碼就是將相對位置坐標全部加上\(window\_size-1\),使得全部為正值:
>>> relative_coords[:, :, 0] += window_size[0] - 1
>>> relative_coords[:, :, 1] += window_size[1] - 1
>>> relative_coords
tensor([[[2, 2],
[2, 1],
[2, 0],
[1, 2],
[1, 1],
[1, 0],
[0, 2],
[0, 1],
[0, 0]],
[[2, 3],
[2, 2],
[2, 1],
[1, 3],
[1, 2],
[1, 1],
[0, 3],
[0, 2],
[0, 1]],
[[2, 4],
[2, 3],
[2, 2],
[1, 4],
[1, 3],
[1, 2],
[0, 4],
[0, 3],
[0, 2]],
[[3, 2],
[3, 1],
[3, 0],
[2, 2],
[2, 1],
[2, 0],
[1, 2],
[1, 1],
[1, 0]],
[[3, 3],
[3, 2],
[3, 1],
[2, 3],
[2, 2],
[2, 1],
[1, 3],
[1, 2],
[1, 1]],
[[3, 4],
[3, 3],
[3, 2],
[2, 4],
[2, 3],
[2, 2],
[1, 4],
[1, 3],
[1, 2]],
[[4, 2],
[4, 1],
[4, 0],
[3, 2],
[3, 1],
[3, 0],
[2, 2],
[2, 1],
[2, 0]],
[[4, 3],
[4, 2],
[4, 1],
[3, 3],
[3, 2],
[3, 1],
[2, 3],
[2, 2],
[2, 1]],
[[4, 4],
[4, 3],
[4, 2],
[3, 4],
[3, 3],
[3, 2],
[2, 4],
[2, 3],
[2, 2]]])
下面是將橫縱坐標的相對位置加起來:
>>> relative_position_index = relative_coords.sum(-1)
>>> relative_position_index
tensor([[4, 3, 2, 3, 2, 1, 2, 1, 0],
[5, 4, 3, 4, 3, 2, 3, 2, 1],
[6, 5, 4, 5, 4, 3, 4, 3, 2],
[5, 4, 3, 4, 3, 2, 3, 2, 1],
[6, 5, 4, 5, 4, 3, 4, 3, 2],
[7, 6, 5, 6, 5, 4, 5, 4, 3],
[6, 5, 4, 5, 4, 3, 4, 3, 2],
[7, 6, 5, 6, 5, 4, 5, 4, 3],
[8, 7, 6, 7, 6, 5, 6, 5, 4]])
下面為定義bias,隨機初始化,但是在網絡的迭代訓練中,是會被反向傳播的:
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
trunc_normal_(self.relative_position_bias_table, std=.02)
一個window里面明明只有9個數值,為什么定義bias時,矩陣的維度為\(25, num_heads\)呢?**
這是因為上面我們加了\(window\_size -1\):不加之前最大值為\(window\_size-1\),后面在加上\(window\_size-1\),此時,最大值為\(2\times window\_size -2\),再算上零,一共有\(2\times window\_size-1\)。所以再初始化bias的時候,我覺得維度為\(2\times window\_size-1\)
就夠了,不知道為什么要定義\((2\times window\_size-1)\times 2\times window\_size-1\)呢?。
每個window是獨立attention的,所以每個window的relative_position_bias都是一樣的。
下面就是加MASK操作了,不再贅述。
5.2 Shifted-window-attention
代碼位於\(SwinTransformerBlock\)中。
當\(shift\_size>0\)時:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# nW, window_size, window_size, 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
>>> img_mask = torch.zeros((1, 3, 3, 1))
>>> h_slices=(slice(0,-3),slice(-3,1),slice(-1,None))
>>> h_slices
(slice(0, -3, None), slice(-3, 1, None), slice(-1, None, None))
>>> w_slices=(slice(0,-3),slice(-3,-1),slice(-1,None))
>>> w_slices
(slice(0, -3, None), slice(-3, -1, None), slice(-1, None, None))
>>> cnt = 0
>>> for h in h_slices:
... for w in w_slices:
... img_mask[:,h,w,:]=cnt
... cnt+=1
...
>>> img_mask
tensor([[[[4.],
[4.],
[5.]],
[[4.],
[4.],
[5.]],
[[7.],
[7.],
[8.]]]])
# nW, window_size, window_size, 1 [1,3,3,1]
mask_windows = window_partition(img_mask, self.window_size)
# 經過window_partition后,沒有發生變化
>>> mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
tensor([[4., 4., 5., 4., 4., 5., 7., 7., 8.]])
>>> attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
>>> attn_mask.size()
torch.Size([1, 9, 9])
>>> attn_mask
tensor([[[ 0., 0., 1., 0., 0., 1., 3., 3., 4.],
[ 0., 0., 1., 0., 0., 1., 3., 3., 4.],
[-1., -1., 0., -1., -1., 0., 2., 2., 3.],
[ 0., 0., 1., 0., 0., 1., 3., 3., 4.],
[ 0., 0., 1., 0., 0., 1., 3., 3., 4.],
[-1., -1., 0., -1., -1., 0., 2., 2., 3.],
[-3., -3., -2., -3., -3., -2., 0., 0., 1.],
[-3., -3., -2., -3., -3., -2., 0., 0., 1.],
[-4., -4., -3., -4., -4., -3., -1., -1., 0.]]])
>>> attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
>>> attn_mask
tensor([[[ 0., 0., -100., 0., 0., -100., -100., -100., -100.],
[ 0., 0., -100., 0., 0., -100., -100., -100., -100.],
[-100., -100., 0., -100., -100., 0., -100., -100., -100.],
[ 0., 0., -100., 0., 0., -100., -100., -100., -100.],
[ 0., 0., -100., 0., 0., -100., -100., -100., -100.],
[-100., -100., 0., -100., -100., 0., -100., -100., -100.],
[-100., -100., -100., -100., -100., -100., 0., 0., -100.],
[-100., -100., -100., -100., -100., -100., 0., 0., -100.],
[-100., -100., -100., -100., -100., -100., -100., -100., 0.]]])
我們再來看forward函數對x的shift操作:
>>> x=torch.arange(0,9)
>>> x
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
>>> x=x.unsqueeze(0)+x.unsqueeze(1)
>>> x
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9],
[ 2, 3, 4, 5, 6, 7, 8, 9, 10],
[ 3, 4, 5, 6, 7, 8, 9, 10, 11],
[ 4, 5, 6, 7, 8, 9, 10, 11, 12],
[ 5, 6, 7, 8, 9, 10, 11, 12, 13],
[ 6, 7, 8, 9, 10, 11, 12, 13, 14],
[ 7, 8, 9, 10, 11, 12, 13, 14, 15],
[ 8, 9, 10, 11, 12, 13, 14, 15, 16]])
>>> x=x.unsqueeze(2).unsqueeze(0)
>>> x.size()
torch.Size([1, 9, 9, 1])
>>> shifted_x = torch.roll(x, shifts=(-1, -1), dims=(1, 2))
>>> xx=shifted_x.squeeze(3).squeeze(0)
>>> xx
tensor([[ 2, 3, 4, 5, 6, 7, 8, 9, 1],
[ 3, 4, 5, 6, 7, 8, 9, 10, 2],
[ 4, 5, 6, 7, 8, 9, 10, 11, 3],
[ 5, 6, 7, 8, 9, 10, 11, 12, 4],
[ 6, 7, 8, 9, 10, 11, 12, 13, 5],
[ 7, 8, 9, 10, 11, 12, 13, 14, 6],
[ 8, 9, 10, 11, 12, 13, 14, 15, 7],
[ 9, 10, 11, 12, 13, 14, 15, 16, 8],
[ 1, 2, 3, 4, 5, 6, 7, 8, 0]])
那為什么要加mask呢,它是由一個假設的,假設各個window之間不相關,各個window單獨做attention。其中我們以\(H=9,W=9, window\_size=3, shift\_size=1\)為例。
圖片參考:SWin Transformer
沒有進行shift的window划分圖:
再forward中,是會對輸入的x進行shift操作的:
shift后的操作:
其中,上圖黑線代表原來的邊界。每個彩色框代表經過shift操作后的window划分,可以看到每個彩色框內部黑線位置是一樣的;黑線是window的邊界。
>>> attn_mask
tensor([[[ 0., 0., -100., 0., 0., -100., -100., -100., -100.],
[ 0., 0., -100., 0., 0., -100., -100., -100., -100.],
[-100., -100., 0., -100., -100., 0., -100., -100., -100.],
[ 0., 0., -100., 0., 0., -100., -100., -100., -100.],
[ 0., 0., -100., 0., 0., -100., -100., -100., -100.],
[-100., -100., 0., -100., -100., 0., -100., -100., -100.],
[-100., -100., -100., -100., -100., -100., 0., 0., -100.],
[-100., -100., -100., -100., -100., -100., 0., 0., -100.],
[-100., -100., -100., -100., -100., -100., -100., -100., 0.]]])
我們舉個例子說明:
以第一個彩色框為例,第一行代表第一個元素是否可以看到對應位置的元素(0代表看得到,-100代表看不到)。當兩個像素點位於不同的window之中(黑線)就是看不到,就賦給一個負數,后面再做softmax,對應權重就會非常小。
5.3 Dual Up Sample
使用PixelShuffle和bilinear合起來的特征作為輸出。
5.4 Absolute position embedding
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
大概位置再SUNet的635行左右。
那為什么要加Absolute position embedding呢?
是因為再5.2 window attention中呢,是有一個relative position的,但是relative position作用范圍僅僅是在一個window里面,即每個window相同位置上的relative position都是一樣的。所以需要absolute position。
5.5 DownSampling
下采樣是通過Patch Embedding實現的。