摘要
CNN由於卷積操作的局部性,難以學習全局和長范圍的語義信息。交互。 提出swin-unet,是一個像Unet的純transformer,用於醫學圖像分割。采用層級的帶移動窗口的swin transformer作為編碼器,提取上下文特征。一個對稱的、帶有patch展開層的、基於swin-transformer的解碼器用於上采樣操作,恢復特征圖的空間分辨率。 在直接下采樣輸入和上采樣輸出4倍時,在多器官和心臟分割任務上證明,提出的網絡超過了全卷積或卷積和transformer的結合方法。模型和代碼將公開在:https://github.com/HuCaoFighting/Swin-Unet
方法
結構概覽
- patch大小為4x4,因此每個patch的特征維度是4x4x3=48,之后應用線性嵌入層投影特征維度到任意維度(表示為C)。
- 轉換的patch token經過幾個swin transformer塊和patch合並層,以產生層級特征表示。patch合並層用於下采樣和升維,swin transformer塊用於特征表示學習。
- 受unet啟發,設計對稱的基於transformer的解碼器,由swin transformer塊和patch擴展層組成。提取的上下文特征通過跳躍連接於編碼器的多尺度特征融合,以補償下采樣造成的空間信息損失
-
相比於patch合並層,一個patch擴展層變形毗鄰維度的特征圖到一個2倍上采樣的大特征圖。最后的patch擴展層用於直徑4倍的上采樣,以恢復特征圖的分辨率到輸入大小(wxh),然后線性投影層用於這些上采樣的特征,輸出像素級別的分割預測
swin transformer塊
swin transformer基於移動窗口構建,圖2中展示了兩個連續的swin transformer塊,
每個swin transformer塊由LN層、多頭注意力模塊、殘差連接和2層的帶有GELU的MLP組成。基於窗口的多頭注意力(W-MSA)和基於移動窗口的多頭自注意力模塊(SW-MSA),用於后續的兩個transformer塊。基於這個窗口切分機制,連續的swin transformer塊可以表示為:
自注意力計算方法為:
M2代表patch數量,d代表q或k的維度。B中的值來源於偏置矩陣。
Encoder
- C維的分辨率維H/4 x W/4的標記化輸入輸入到兩個連續的Swin Transformer塊中進行表示學習,特征維度和分辨率不變。同時,patch合並層將縮減token的數量(2倍下采樣),提升特征維度到原始維度的2倍。這個步驟在編碼器中重復3次。
- patch合並層:輸入patch分成4部分,通過patch合並層拼接在一起,采用這個過程,特征分辨率將下采樣2倍,因為拼接操作導致特征維度提升4倍,一個線性層用於拼接的特征,將特征維數統一為2×原始維數
BottleNeck
只有兩個連續swin transformer塊用於構建bottleneck,學習深度特征表示,bottleneck中特征維度和分辨率不變
Decoder
- 基於swin transformer塊構建,在解碼器中采用patch擴展層,用於刪改楊提取的特征,patch擴展層將相鄰維度的特征圖變為更高分辨率(2倍上采樣),維度數量減半。
- patch 擴展層:以第一個patch擴展層為例,上采樣前,一個線性層應用到輸入特征(W/32 x H/32 X 8C)以增加特征維度到原來的2倍(W/32 x H/32 X 16C)。然后,采用重排操作,擴展原輸入特征的分辨率2倍,同時縮減特征維度到輸入維度的1/4(W/32 × H/32 × 16C -> W/16 × H/16 × 4C)
Skip Connection
將淺層特征和深度特征拼接到一起,以縮減下采樣引起的空間分辨率損失。后跟1個線性層,拼接的特征維度與上采樣的特征維度相同
實驗
公眾號