Uncertainty-aware Self-ensembling Model for Semi-supervised 3D Left Atrium Segmentation(理解)


原文鏈接

掃碼關注下方公眾號:"Python編程與深度學習",領取配套學習資源,並有不定時深度學習相關文章及代碼分享。

 

今天分享一篇發表在MICCAI 2019上的論文:Uncertainty-aware Self-ensembling Model for Semi-supervised 3D Left Atrium Segmentation (原文鏈接:[1],代碼鏈接:[2])。

1 研究背景

訓練深度卷積神經網絡通常需要大量的標簽數據,然而對於醫學影像分割任務,大量數據的標注成本很高,因此考慮怎么同時利用好僅有的標簽數據和無標簽數據(半監督方法)在醫學影像處理中是非常重要的。這篇文章針對3D MR圖像的左心房分割任務提出了不確定性感知自增強模型,能夠更有效地利用無標簽數據從而獲得更好的性能。

2 方法

2.1 整體流程

如上圖(Fig.1)所示,對於有標簽數據,學生模型 (student model)進行有監督學習。對於無標簽數據,通過教師模型 (teacher model)預測分割圖,作為學生模型 (student model)的學習目標,並同時評估學習目標的不確定性。基於學習目標的不確定性,采用一致性損失函數提高學生模型的性能。

2.2 半監督分割 (Semi-supervised segmentation)

對於3D數據的半監督任務,假設有$N$個標簽數據和$M$個無標簽數據,那么有標簽數據集可以表示為$\mathcal{D}_L=\{(x_i,y_i)\}_{i=1}^N$,無標簽數據集可以表示為$\mathcal{D}_U=\{(x_i)\}_{i=N+1}^{N+M}$,其中$x_i\in\mathbb{R}^{H\times W\times D}$是輸入數據,$y_i\in \{0,1\}^{H\times W\times D}$是標簽數據。文中的半監督分割框架的學習目標為:
$$\min_{\theta}\sum_{i=1}^{N}\mathcal{L}_s(f(x_i;\theta))+\lambda\sum_{i=1}^{N+M}\mathcal{L}_c(f(x_i;\theta',\xi'),f(x_i;\theta,\xi))$$
其中$\mathcal{L}_s$為在有標簽數據上計算的有監督損失部分(交叉熵損失),$\mathcal{L}_c$為在無標簽數據上計算的教師模型和學生模型之間的無監督損失部分。$f(\cdot)$表示分割神經網絡,$(\theta',\xi')$和$(\theta,\xi)$分別表示教師模型和學生模型中的參數和的不同擾動(例如給輸入加入噪聲或者網絡中加入dropout)。$\lambda$是控制有監督損失部分和無監督損失部分的權衡參數。

此外,[9][14]中證明了集成網絡在不同訓練階段的預測結果能夠有效地提高預測結果,因此文中采用了指數移動平均 (exponential moving average, EMA)策略來提高教師模型的預測結果。具體地,教師模型的參數$\theta'$的更新策略為:
$$\theta_t'=\alpha\theta_{t-1}'+(1-\alpha)\theta_t$$
其中$\theta_t$是學生模型在第$t$次訓練迭代中的參數,$\alpha$是用來控制EMA更新速度的參數。

2.3 不確定性感知 (Uncertainty-Aware Mean Teacher Framework)

教師模型對於無標簽數據的預測結果是不確定性且有噪聲的,而這些預測結果將作為學生模型學習的一部分 ($\mathcal{L}_c$),因此作者設計了不確定性感知策略使得學生模型能夠逐漸學習更加可靠的目標。具體地,對於訓練圖像,教師模型不僅要預測它們的分割圖,還要評估它們的不確定性。然后學生模型在學習中只選取其中具有更低的不確定性(更加可靠)的數據計算一致性損失 (consistency loss)。

2.3.1 不確定性評估 (Uncertainty Estimation)

不確定評估是由教師模型生成的,具體有:
1. 對於每一個輸入數據,進行$T$次前向傳播獲得預測結果,每一次都隨機對輸入數據加入高斯噪聲或者在網絡中加入隨機dropout。因此每一個體素都有$T$個預測結果,可以表示為$\{\mathbf{p}_t\}_{t=1}^T$
2. 采用預測熵 (predictive entropy)大致近似不確定性,具體有:$\mu_c=\frac{1}{T}\sum_t\mathbf{p}^c_t$,$u=-\sum_{c}\mu_clog\mu_c$,其中$\mathbf{p}_t^c$是對在第$t$次前向傳播中對屬於第$c$類別概率的預測。最終可以構成一個不確定性張量$U,\{u\}\in\mathbb{R}^{H\times W\times D}$

2.3.2 基於不確定性的一致性損失函數 (Uncertainty-Aware Consistency Loss)

有了上一步的教師模型預測的不確定性結果$U$,可以過濾掉相對不確定的預測,而選取相對可靠的預測作為學生模型的學習目標。具體如下:
$$\mathcal{L}_c(f',f)=\frac{\sum_v\mathbb{I}(u_v<{H})\left \|f_v'-f_v\right \|^2}{\sum_v\mathbb{I}(u_v<{H})}$$
其中$\mathbb{I}$是指示函數,如果條件成立則返回1,否則返回0,用以篩選出可靠的樣本。$f_v'$和$f_v$分別是教師網絡和學生網絡在第$v$個體素位置的預測結果。$u_v$是不確定性張量$U$在第$v$個體素上的值,$H$是過濾不確定預測的閾值。作者提到,加入了基於不確定性的一致性損失函數,能夠同時提高教師模型和學生模型的性能。

3 實驗結果

這里我只給出論文中的部分實驗結果,具體的實驗結果分析以及實驗和參數的設置請看原文。

4 參考資料

[1] https://arxiv.org/pdf/1907.07034

[2] https://github.com/yulequan/UA-MT

[3] Laine, S., Aila, T.: Temporal ensembling for semi-supervised learning. arXiv preprint (2016)

[4] Tarvainen, A., Valpola, H.: Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. In: NIPS (2017)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM