github代碼:https://github.com/Chet1996/pytorch-UNet(如有幫助,點個星星hi!)
0 - Abstract
這篇文章是生物學會議ICMICCAI2015的文章,主要針對的是生物影像進行分割。由於普遍認為深度學習需要大量的樣本進行訓練,而生物醫學領域上的數據量比較少,所以本文提出了一種網絡和訓練策略,依靠數據增強等技巧有效的利用了有限的標簽信息。該體系結構包括捕捉上下文的收縮路徑(contracting path)和實現精確定位的對稱擴展路徑(symmetric expanding path)。實驗表明,該網絡結構可以在非常少的圖像數據集上進行端到端訓練。
1 - Introduction & Network Architecture
Ciresan等人使用滑動窗口,提高圍繞該像素的局部區域(補丁)作為輸入來預測每個像素的類別標簽。雖然該方法可以達到很好的精度,但是存在兩個缺點:
- 速度非常慢。因為網絡必須分別為每個補丁運行,並且由於補丁的重疊造成大量的冗余;
- 精確度和局部區域(補丁)大小的權衡。較大局部區域帶來更多的信息但需要更多的緩沖層(例如最大池化層)來處理,較小局部區域使得上下文信息變少。
本文提出的網絡,是全卷積網絡,其中主要是想是通過逐層擴充來補充通常的收縮網絡(下采樣),其中pooling被unsampling操作代替(稱之為上采樣),這些層能夠增加輸出的分辨率。為了精准定位每一個像素,下采樣和上采樣路徑中相同尺度的特征進行連接,整體架構如下圖:
從上圖可以看到,作者采用的3x3卷積的padding設置為0(unpadded的卷積),因此每一次卷積都會使得特征尺度在h和w上均減少2,從而使得,上采樣路徑得到的特征圖尺度和下采樣路徑的特征圖尺度不完全相同(下采樣的特征圖大於上采樣),所以需要先對下采樣特征圖進行裁剪之后再和上采樣特征圖進行連接(即圖中表述的copy and crop)。(我猜想,應該是當時文章發出的時候,計算力的限制,導致要求輸出和輸入具有同等大小的分辨率會犧牲很多的實效性,因此作者做了如此一個權衡。在實驗過程和我的認識中,如果對於特征圖進行裁剪,是會損失一些特征信息的,因此我在我的代碼中並沒有完全按照文章的做法復現,而是把輸入和輸出都統一到512x512的分辨率,並且3x3卷積的padding都設置為1)
2 - Training
- input/output:輸入為572x572分辨率的圖像,標簽為388x388的分割圖;
- batch size:為了最小化開銷並最大限度地利用GPU內存,我們傾向於使用大的輸入塊而不是大的批處理大小,從而將批處理減少到單個圖像,即batch_size設置為1;
- optimizer:SGD(隨機梯度下降)優化器,其momentum(動量)設置為0.99,使得幾乎所有之前訓練的樣本都能影響到當前訓練樣本的更新(我覺得就和batch size設置得比較大的效果應該是一樣的);
- criterion:交叉熵損失函數(但我在我的代碼實現中使用了sigmoid+BCELoss代替了交叉熵損失函數),作者通過預先計算每個真實分割的權重圖,來補償訓練集中不同類別的不同頻率,並迫使網絡學習我們的觸摸單元之間引入的小分離邊界。分離邊界使用形態學操作來計算,計算權重圖通過公式$w(x)=w_c(x)+w_0*exp(-\frac{(d_1(x)+d_2(x))^2}{2\sigma^2})$,其中$w_c$是權重圖用來平衡像素的頻率,$d_1$表示最近單元邊界的距離,$d_2$表示到第二進單元的邊界的距離,文中設置$w_0= 10, \sigma\approx 5pixels$(涉及到形態學和邊界的部分還沒有搞懂,后續需要補充);
- initialize:文中提出使用標准偏差為$\sqrt{\frac{2}{N}}$的高斯分布來初始化卷積網絡的kernel,其中N表示一個神經元輸入節點的數量,例如3x3的64通道的卷積層的$N= 3*3*64= 576$;
- data augmentation:文中主要使用移位、旋轉、變形、灰度值變化等數據增強方法,其中似乎是訓練樣本的隨機彈性變形是訓練具有很少標簽的分割網絡的關鍵。文中使用隨機位移矢量在粗糙的3x3網絡上生成平滑變形,位移從10像素便准偏差的高斯分布中采樣,然后使用雙三次插值計算每個像素位移。下采樣路徑末尾的dropout層執行進一步的隱式數據增強(這一點似乎網絡結構圖沒有體現,按文中的意思應該是在下采樣路徑末尾加入了dropout層從而防止過擬合而達到相當於圖像增強的效果)。
3 - My code
https://github.com/Chet1996/pytorch-UNet
我基於文中的思想和文中提到的EM segmentation challenge數據集大致復現了該網絡(github代碼)。其中為了代碼的簡潔方便,有幾點和文中提出的有所不同:
- 將輸入輸出統一到512x512(文中輸入為572x572,輸出為388x388);
- 將輸出的通道數改為1,而后接上sigmoid激活,再用BCELoss計算損失(文中輸出通道為2,而后通過softmax激活,再用交叉熵損失函數計算損失);
- 只采用隨機水平、垂直翻轉作為數據增強(文中采用了移位、旋轉、變形、灰度值變化等數據增強方法,並且似乎在下采樣路徑末尾加入了dropout層);
- 沒有引入文中所提到的分離邊界的技巧;
- 加入了batch normalize層。
我的訓練參數如下:
- train/val:將前28張圖片作為訓練集,后2張圖片作為驗證集;
- data augmentation:隨機水平、垂直翻轉;
- input/output:1x512x512;
- optimizer:SGD優化器,其中lr設置為0.01,momentum設置為0.99,weight_decay設置為0.0005;
- criterion:BCELoss;
- epochs:60;
- batch size:1;
- lr decay:每30個epoch衰減為原來lr的0.1;
- initialize:文中提到的初始化方法;
- batch normalize:在每一層卷積層后面加入了bn層。
訓練數據可視化如下圖:
訓練集和驗證集的loss變化曲線如下圖:
在驗證集上的預測效果如下圖(第一張圖為輸入圖片,第二張圖為標簽,第三張圖為網絡預測結果):
4 - 參考資料
https://blog.csdn.net/u014451076/article/details/79424233
https://blog.csdn.net/shine19930820/article/details/80098091
https://github.com/Chet1996/pytorch-UNet