GRDN:分組殘差密集網絡,用於真實圖像降噪和基於GAN的真實世界噪聲建模


GRDN:分組殘差密集網絡,用於真實圖像降噪和基於GAN的真實世界噪聲建模

摘要

隨着深度學習體系結構(尤其是卷積神經網絡)的發展,有關圖像去噪的最新研究已經取得了進展。但是,現實世界中的圖像去噪仍然非常具有挑戰性,因為不可能獲得理想的地面對圖像和現實世界中的噪聲圖像對。由於最近發布了基准數據集,圖像去噪社區的興趣正朝着現實世界中的去噪問題發展。在本文中,我們提出了分組殘差密集網絡(GRDN),它是最新的殘差密集網絡(RDN)的擴展和通用體系結構。RDN的核心部分定義為分組殘差密集塊(GRDB),並用作GRDN的構建模塊。我們通過實驗表明,通過級聯GRDB可以顯着改善圖像降噪性能。除了網絡架構設計之外,我們還開發了一種新的基於對抗網絡的真實世界的噪聲建模方法。我們通過在NTIRE2019實像去噪挑戰賽道2:sRGB中的峰值信噪比和結構相似性方面獲得最高分,證明了所提出方法的優越性。.

一、簡介

在圖像去噪領域,最近的研究表明,基於學習的方法比以前的手工方法(例如塊匹配3D(BM3D)[6]及其變體)更加有效。對於基於學習的方法,擁有足夠數量的高質量數據集至關重要。由於可以通過在無噪聲圖像上添加合成噪聲來輕松構建一對嘈雜且無噪聲的圖像,因此,大多數以前的基於學習的方法都專注於經典的高斯去噪任務,並且最關注網絡的體系結構設計。尤其是卷積神經網絡(CNN)。然而,由於合成噪聲圖像和真實噪聲圖像之間的差距,發現使用合成圖像訓練的CNN在真實噪聲圖像上表現不佳,有時甚至不如BM3D [22]。

這些作者做出了同樣的貢獻。通訊作者:S.-W.榮格*

對於現實世界的圖像去噪,主要有兩種方法。第一種方法是找到一種更好的真實噪聲統計模型,而不是加性高斯白噪聲[3,8,10,19,23]。特別是,高斯分布和泊松分布的組合顯示出可以緊密地模擬依賴信號和不依賴信號的噪聲。使用這些新的合成噪點圖像訓練的網絡證明了在消除現實世界的噪點圖像方面的優越性。這種方法的一個明顯優勢是,只需將合成噪聲添加到無噪聲的地面圖像中,我們就可以擁有無限多的訓練圖像對。但是,是否可以通過統計模型來模擬現實世界的噪聲仍然是有爭議的。因此,第二種方法是相反的方向。從真實的嘈雜圖像中,可以通過反轉圖像采集過程[1、4、24、22、2]獲得幾乎無噪聲的地面真實圖像。據我們所知,智能手機圖像去噪數據集(SIDD)[1]是第二種方法中最大的高質量圖像數據集之一。但是,提供的圖像數量可能不足以訓練大型網絡,並且沒有足夠的專業知識,很難從真實的嘈雜圖像中生成真實的圖像。因此,我們采用第二種方法,但應用了我們自己的基於生成對抗網絡(GAN)的數據增強技術來獲取更大的數據集。

網絡架構當然是最重要的。在基於CNN的圖像恢復中,密集殘差塊(RDB)[33、32]受到了極大關注。在本文中,我們提出了一種稱為分組殘差密集網絡(GRDN)的新體系結構。特別是,提出的體系結構采用了最近的殘差密集網絡(RDN)作為具有較小修改的組件,並將其定義為分組的殘差密集塊(GRDB)。通過將GRDB與關注模塊進行級聯,我們可以獲得現實世界中圖像去噪任務的最新性能[28]。在NTIRE2019實像去噪挑戰-軌道2:sRGB中,我們在39.93 dB的峰值信噪比(PSNR)和0.9736的結構相似度(SSIM)方面取得了最佳性能。

img

​ 圖1:提出的網絡架構:GRDN

二、相關工作

2.1 影像還原

圖像降噪是圖像處理中研究最廣泛的主題之一。由於深度學習的顯着進步,基於CNN的方法現在在圖像去噪中占主導地位。但是,大多數以前的基於學習的圖像去噪方法都集中在經典的高斯去噪任務上。對於現實世界的圖像降噪,第一種方法是通過使用不同的相機設置來捕獲一對嘈雜且無噪點的圖像[2,22]。在[22]中表明,較早的基於學習的方法與經典方法(如BM3D)可比甚至有時不如BM3D。我們認為這主要是由於訓練數據集的質量和數量不足。因此,開發了更豐富和完善的數據集,例如Darmstadt噪聲數據集(DND)和SIDD [1],並且最近的基於學習的方法[1、3、10、23]顯示出它們優於經典方法在現實世界中的圖像去噪。

除了努力生成高質量的數據集以外,還進行了大量研究以找到更好的網絡體系結構以進行圖像去噪。從CNN的角度來看,為不同的圖像恢復任務(如圖像去噪,圖像去模糊,超分辨率和壓縮偽像減少)開發的網絡體系結構具有相似性。反復證明,為某種圖像還原任務開發的一種體系結構在其他還原任務中也表現良好[30、32、23]。因此,我們檢查了為不同圖像恢復任務而開發的許多體系結構,尤其是超分辨率[7,13,16,17,14,26,33,31,11,18]。其中,RDN [33,32]和殘留信道關注網絡(RCAN)[31]與我們的網絡體系結構關系最密切。

特別是,我們嘗試利用RDN和RCAN中的新穎思想。RCAN在殘差(RIR)體系結構中引入了殘差,消融研究表明RIR的性能增益最為顯着。因此,我們在架構設計中使用了RIR原理。另外,RDN本身是一個圖像恢復網絡,但是我們將它與修改一起用作我們的網絡的組成部分,並構造了一個RDN的級聯結構作為我們的圖像去噪網絡。最近的研究還表明注意模塊的有效性。在許多注意力模塊中,卷積塊注意力模塊(CBAM)[28]是一種易於植入的模塊,可以順序地估計通道的注意力和空間的注意力,在一般物體檢測和圖像分類中顯示出了效率,因此我們將CBAM納入了我們的網絡。

2.2. GAN

諸如SIDD和DND之類的可公開獲得的真實世界圖像降噪數據集中的訓練圖像數量可能不足以訓練深度和廣泛的神經網絡。擴充這些數據集的一種可行方法是利用GAN的功能[9]。第一種基於GAN的真實世界的噪聲建模方法[5]僅使用真實世界的噪聲圖像訓練噪聲生成器,其中鑒別器被訓練為區分真實和模擬噪聲信號。然后,使用噪聲發生器將合成的但逼真的噪聲添加到無噪聲的地面圖像中,並使用生成的成對的地面圖像和高噪聲圖像最終訓練去噪網絡。通過使用GAN生成的數據集,現實世界中的圖像降噪性能得到了顯着改善。

通過將諸如無噪聲圖像補丁,ISO和快門速度之類的調節信號作為生成器的附加輸入,我們改進了以前基於GAN的實際噪聲仿真技術[5]。對無噪聲圖像斑塊進行調節可以幫助生成更逼真的與信號相關的噪聲,而其他相機參數可以提高可控性和各種模擬噪聲信號。我們還通過使用最新的相對論GAN [12]來更改先前體系結構的鑒別符[5]。與常規GAN不同,相對論GAN的判別器學會了確定真實數據與偽數據之間哪個更為現實。我們的方法與傳統相對論GAN的不同之處在於,真實數據和偽數據都被用作輸入,以使鑒別器更明確地比較這兩個數據。

img

​ 圖2:GRDN的組件:(a)RDB和(b)GRDB

三、提出的方法

3.1 圖像去噪網絡

我們的稱為GRDN的圖像去噪網絡架構如圖1所示。我們的設計原則是分配每一層的負擔,以便可以更好地訓練更深更廣的網絡。為此,將殘余連接應用於四個不同級別。下采樣層和上采樣層被包括在內,以實現更深,更寬的架構,並且還應用了CBAM [28]。

受RDN [33]的啟發,我們使用如圖2(a)所示的RDB作為構建模塊。在RDN中,來自級聯RDB的要素被串聯在一起,然后是1×1卷積層。如圖2(b)所示,我們將RDN的功能串聯部分定義為GRDB,並將其用作GRDN的構建模塊。請注意,原始RDN [33]在GRDB之前和之后應用卷積層,並使用全局殘差學習進行圖像去噪。但是,我們認為RDN給GRDB的最后1×1卷積層帶來了沉重的負擔。因此,我們改為級聯GRDB,以便可以將RDB中的功能分為多個階段。受包括RDN [33]在內的許多最新圖像恢復網絡的推動,我們還包括了全局殘差連接,因此該網絡可以專注於學習噪聲圖像和真實圖像之間的差異。最后,我們將CBAM作為構建模塊來進一步提高去噪性能。CBAM塊的位置根據經驗選擇在上卷積層和最后一個卷積層之間。

img

​ 圖3:cERGAN生成器

盡管GRDN在結構上比RDN更深[33,32],但我們使用了相同數量的RDB。具體來說,在原始RDN中使用了16個RDB進行圖像降噪。我們使用4個GRDB堆棧,每個GRDB包含4個RDB,因此GRDN中有16個RDB。

3.2 基於GAN的真實世界噪聲建模

受最新技術[5]的啟發,我們開發了自己的發生器和鑒別器用於實際噪聲建模。與先前的技術[21]類似,我們使用殘差塊(ResBlocks)作為生成器的構建模塊。但是,我們進行了一些修改以提高實際噪聲建模的性能。圖3顯示了生成器架構。首先,我們包含調節信號:無噪聲的圖像補丁,ISO,快門速度和智能手機型號,作為發生器的附加輸入。對無噪聲圖像斑塊進行調節可以幫助生成更逼真的與信號相關的噪聲,而其他與相機相關的參數可以提高可控性和各種模擬噪聲信號。為了用這些調節信號訓練生成器,我們使用了SIDD [1]的元數據。第二,頻譜歸一化(SN)[20]在像[29]中所使用的基本卷積單元中在批量歸一化之前應用。第三,我們的ResBlock包含剩余縮放比例[25、18、27]。從經驗上發現,SN和殘留水垢對訓練我們的發電機很有用。

img

​ 圖4:CERGAN鑒別器

如圖4所示,我們的鑒別器架構也不同於以前的基於GAN的噪聲仿真技術[5]。增強的超分辨率GAN(ESGAN)[27]表明相對論GAN [12]可有效地生成逼真的圖像紋理。與原始GAN [9]不同,相對論GAN的判別者學會了確定真實數據與偽數據之間哪個更為現實。令img表示輸入圖像x的未變換鑒別器輸出。然后可以將標准鑒別符表示為img,σ是S型函數。ESGAN中采用的相對論平均GAN(RaGAN)的定義為:

img

其中,imgimg分別表示真實數據和偽數據,而img表示期望運算符,該期望運算符應用於迷你批處理中的所有數據[27]。定義為條件顯式相對論GAN(cERGAN)的擬議網絡的鑒別符為

img

其中img表示調節信號。具體來說,我們通過復制值使每個條件數據的大小與訓練補丁的大小相同,因此我們的img由4個補丁組成:來自智能手機代碼的3個常量補丁(例如Google Pixel = 0,iPhone 7 = 1等),ISO級別,快門速度和一個無噪點的圖像補丁。除了img之外,我們還同時使用imgimg作為鑒別符的輸入。請注意,ESGAN使用imgimg作為鑒別符的輸入。

生成器和鑒別器的損失函數分別表示為imgimg,其定義如下:

img

換句話說,如果第二個輸入是img,而第三個輸入是img,則辨別器將經過訓練以預測接近1的值,即imgimg更現實。如果切換了兩個輸入,則鑒別器將被訓練為預測接近0的值,即img的真實性不如img。訓練了生成器以欺騙鑒別器。通過要求網絡在真實數據和假數據之間進行顯式比較,我們可以模擬更真實的真實噪聲。

四、實驗

我們使用PyTorch庫,Intel i7-8700 @ 3.20GHz,32GB RAM和NVIDIA Titan XP來實現所有模型。

img

​ 表1:圖像去噪模型的比較。

4.1 數據集

我們使用了NTIRE 2019實像去噪挑戰的訓練和驗證圖像,它是SIDD數據集的子集[1]。讓ChDB表示我們用於實驗的數據集。具體來說,分別使用320個高分辨率圖像和1280個尺寸為256×256的裁剪圖像塊進行訓練和驗證。提供的圖像是由五個智能手機相機拍攝的-Apple iPhone 7,Google Pixel,三星Galaxy S6 Edge,摩托羅拉Nexus 6和LG G4。由於測試數據集的真實圖像不公開,因此我們在本節中使用驗證數據集報告圖像去噪模型的性能。由於我們注意到地面真實圖像中圖像邊界周圍的非邊際劣化,因此在生成訓練補丁時,我們排除了第一行和最后8行/列。沒有應用諸如縮放,翻轉和旋轉之類的常規數據增強技術。

4.2 圖像去噪

4.2.1 實施細節

我們通過兩種方式擴充了提供的訓練數據集。首先,我們使用作者提供的源代碼[10]將合成噪聲添加到真實圖像中。我們還應用了第3.2節中介紹的基於GAN的噪聲模擬器,以生成其他合成噪聲圖像。

在每個訓練批次中,我們隨機提取16對真實的圖像和嘈雜的圖像塊。我們使用Adam [15]進行了訓練,其img = 0.9,img = 0.999。初始學習率設置為img,然后在每次img迭代時降低到一半。我們使用img損失訓練了網絡。我們訓練了大約5天的模型。

對於上/下卷積層,我們使用了4×4過濾器,對來自RDB的級聯特征使用了1×1過濾器。否則,我們使用3×3濾波器。使用零填充,並且未對所有卷積層使用膨脹。每個RDB具有8對卷積層和ReLU激活層。

4.2.2與RDN的比較

首先,我們將GRDN模型與RDN [32]進行了比較。實驗結果如表1所示。我們使用ChDB重新訓練了RDN。表1中的第一和第二列對應於RDN和建議的GRDN。可以看出,我們模型的PSNR比RDN高0.04 dB。請注意,RDN和GRDN具有相同數量的RDB,因此參數的數量相似。具體來說,我們的基本GRDN模型具有22M參數,而RDN具有21.9M參數。

4.2.3 補丁大小實驗

由於原始圖像分辨率很高(超過1200萬像素),因此需要使用最大可能的色塊大小來包含足夠的圖像內容。因此,我們將補丁大小增加到96×96,這是我們實驗環境中最大的大小。通過比較表1的第2列和第5列,我們可以看到,通過增大貼片尺寸可以獲得0.22dB的顯着性能增益。

4.2.4 CBAM模塊的實驗

CBAM [28]是一個簡單但有效的CNN模塊。因為它是輕量級的通用模塊,所以可以輕松地將其植入任何CNN架構中,而無需大量增加參數數量。特別是,CBAM可以置於網絡的瓶頸處。由於我們具有下采樣和上采樣層,因此我們檢查了CBAM的不同位置和組合。我們得出結論,對於我們的模型,CBAM的最佳位置是在上采樣層之后。我們認為,這表明CBAM增強了上采樣數據的重要功能。它還有助於為之后的最后一個卷積層構造最終的去噪圖像。發現CBAM的有效性取決於網絡的復雜性。比較表1的第二列和第三列,CBAM將PSNR提高了0.05 dB。但是,在增加了補丁大小之后,CBAM的增益被稀釋了。比較表1的第5列和第6列,CBAM甚至將PSNR降低了0.01 dB。

img

圖5:真實世界噪聲建模的實驗結果:

(a)ChDB中的真實噪聲圖像塊

(b)真實圖像塊

(c)噪聲圖像和真實圖像塊之間的差異

( d)cERGAN產生的噪聲補丁

4.2.5超參數調整

我們比較了具有不同數量的過濾器和GRDB的網絡。比較表1的第6列和第7列,較淺但較寬的網絡性能要好0.02 dB。因此,在我們的硬件限制下,第七列的模型是性能最佳的模型。

4.3 真實世界噪聲建模

為了訓練cERGAN的生成器和判別器,我們從真實的噪聲圖像中裁剪了尺寸為48×48的圖像塊,並從ChDB中裁剪了其真實的圖像。我們將批處理大小為32,將Adam優化器與imgimg一起使用。生成器和鑒別器經過了340k次迭代訓練。區分器和生成器的初始學習率均設置為0.0002,並且我們在320k次迭代后線性降低了學習率,以使最后一次迭代后學習率變為0。圖5示出了由所提出的cERGAN產生的一些噪聲圖像補丁。從圖5和6中可以看出。如圖5(c)和(d)所示,建議的cERGAN可以生成接近真實噪聲的噪聲補丁。

通過比較使用或不使用模擬數據訓練的擬議圖像降噪網絡,可以評估模擬噪聲圖像的有效性。在這里,測試的網絡對應於表1的第4列。我們首先嘗試僅使用cERGAN獲得的合成的現實噪聲圖像訓練圖像去噪網絡。在ChDB驗證集中獲得的平均PSNR為38.63 dB,不如我們僅使用提供的ChDB數據集獲得的PSNR(表1中的39.62 dB)。

img

​ 圖6:具有不同數據集的圖像去噪網絡的收斂性分析。

將統計建模的真實世界的噪聲添加到ChDB的真實圖像中。我們使用這些數據集訓練的圖像去噪網絡僅產生36.17 dB,這表明所提出的基於GAN的噪聲建模至少比統計噪聲建模方法[10]表現更好。

最后,我們將原始的ChDB數據集與通過建議的cERGAN和常規方法生成的合成數據集[10]相結合。在這里,我們只能測試一種配置:來自ChDB的90%,使用[10]的模擬ChDB的5%和使用cERGAN的模擬ChDB的5%。圖6顯示使用增強數據集獲得的PSNR更加穩定地增加。所得的PSNR為39.64 dB,略高於使用原始數據集獲得的PSNR(39.62 dB)。

五. NTIRE2019圖像降噪挑戰

這項工作被提議參加NTIRE2019實像去噪挑戰-Track 2:sRGB。挑戰在於開發一種具有最高PSNR和SSIM的圖像去噪系統。提交的圖像去噪網絡對應於表1的第七列。提交的模型中的一個小更改是,我們每2個GRDB都包含了跳過連接。對於訓練,(35.49 / 0.9812)(29.86 / 0.9314)(26.32 / 0.7576)(19.05 / 0.3623)(39.11 / 0.9899)(39.59 / 0.9902)(37.05 / 0.9749)(37.13 / 0.9748)我們使用增強的ChDB本節中提到的技術。 4.3。我們的模型在PSNR和SSIM方面均在真實圖像降噪方面排名第一。如表2所示,我們的模型優於第二等級方法0.05 dB。

六、結論

在本文中,我們提出了一種用於現實圖像降噪的改進網絡架構。通過廣泛和分層地使用剩余連接,我們的模型獲得了最先進的性能。此外,我們開發了一種改進的基於GAN的實際噪聲建模方法。

盡管我們只能將擬議的網絡評估為現實世界中的圖像降噪,但我們認為擬議的網絡普遍適用。因此,我們計划將提出的圖像去噪網絡應用於其他圖像恢復任務。我們也不能完全和定量地證明所提出的實際噪聲建模方法的有效性。為了更好地進行真實的噪聲建模,顯然必須進行更精細的設計。我們相信,我們的真實世界噪聲建模方法可以擴展到其他真實世界的退化,例如模糊,混疊和霧度,這將在我們的未來工作中得到證明。

七、參考文獻

[1] A. Abdelhamed, S. Lin, and M. Brown. A high-quality denoising dataset for smartphone cameras. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 1692–1700, 2018.

[2] J. Anaya and A. Barbu. RENOIR - A benchmark dataset for real noise reduction evaluation. CoRR, abs/1409.8230, 2014.

[3] T. Brooks, B. Mildenhall, T. Xue, J. Chen, D. Sharlet, and J. T. Barron. Unprocessing images for learned raw denoising. CoRR, abs/1811.11127, 2018.

[4] C. Chen, Q. Chen, J. Xu, and V. Koltun. Learning to see in the dark. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 3291–3300, 2018.

[5] J. Chen, J. Chen, H. Chao, and M. Yang. Image blind denoising with generative adversarial network based noise modeling. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 3155–3164, 2018.

[6] K. Dabov, A. Foi, V. Katkovnik, and K. Egiazarian. Image denoising by sparse 3-d transform-domain collaborative filtering. IEEE Trans. Image Process., 16(8):2080–2095, Aug. 2007.

[7] C. Dong, C. C. Loy, K. He, and X. Tang. Learning a deep convolutional network for image super-resolution. In Proceedings of the European Conference on Computer Vision, pages 184–199. Springer, 2014.

[8] A. Foi, M. Trimeche, V. Katkovnik, and K. Egiazarian. Practical Poissonian-Gaussian noise modeling and fitting for single-image raw-data. IEEE Trans. Image Process., 17(10):1737–1754, 2008.

GRDN網絡結構代碼實現

SubNets.py

import torch
import torch.nn as nn
import torch.nn.functional as F


def weights_init(m):
    """
    custom weights initialization called on netG and netD
    https://github.com/pytorch/examples/blob/master/dcgan/main.py
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

####################################################################################################################


class make_dense(nn.Module):
    def __init__(self, nChannels, nChannels_, growthRate, kernel_size=3):
        super(make_dense, self).__init__()
        self.conv = nn.Conv2d(nChannels_, growthRate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
                              bias=False)
        self.nChannels = nChannels

    def forward(self, x):
        out = F.relu(self.conv(x))
        out = torch.cat((x, out), 1)
        return out

class make_residual_dense_ver1(nn.Module):
    def __init__(self, nChannels, nChannels_, growthRate, kernel_size=3):
        super(make_residual_dense_ver1, self).__init__()
        self.conv = nn.Conv2d(nChannels_, growthRate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
                              bias=False)
        self.nChannels_ = nChannels_
        self.nChannels = nChannels
        self.growthrate = growthRate

    def forward(self, x):
        # print('1', x.shape, self.nChannels, self.nChannels_, self.growthrate)
        # print('2', outoflayer.shape)
        # print('3', out.shape, outoflayer.shape)
        # print('4', out.shape)

        outoflayer = F.relu(self.conv(x))
        out = torch.cat((x[:, :self.nChannels, :, :] + outoflayer, x[:, self.nChannels:, :, :]), 1)
        out = torch.cat((out, outoflayer), 1)
        return out

class make_residual_dense_ver2(nn.Module):
    def __init__(self, nChannels, nChannels_, growthRate, kernel_size=3):
        super(make_residual_dense_ver2, self).__init__()
        if nChannels == nChannels_ :
            self.conv = nn.Conv2d(nChannels_, growthRate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
                                  bias=False)
        else:
            self.conv = nn.Conv2d(nChannels_ + growthRate, growthRate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
                                  bias=False)

        self.nChannels_ = nChannels_
        self.nChannels = nChannels
        self.growthrate = growthRate

    def forward(self, x):
        # print('1', x.shape, self.nChannels, self.nChannels_, self.growthrate)
        # print('2', outoflayer.shape)
        # print('3', out.shape, outoflayer.shape)
        # print('4', out.shape)

        outoflayer = F.relu(self.conv(x))
        if x.shape[1] == self.nChannels:
            out = torch.cat((x, x + outoflayer), 1)
        else:
            out = torch.cat((x[:, :self.nChannels, :, :], x[:, self.nChannels:self.nChannels + self.growthrate, :, :] + outoflayer, x[:, self.nChannels + self.growthrate:, :, :]), 1)
        out = torch.cat((out, outoflayer), 1)
        return out

class make_dense_LReLU(nn.Module):
    def __init__(self, nChannels, growthRate, kernel_size=3):
        super(make_dense_LReLU, self).__init__()
        self.conv = nn.Conv2d(nChannels, growthRate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
                              bias=False)

    def forward(self, x):
        out = F.leaky_relu(self.conv(x))
        out = torch.cat((x, out), 1)
        return out


# Residual dense block (RDB) architecture
class RDB(nn.Module):
    """
    https://github.com/lizhengwei1992/ResidualDenseNetwork-Pytorch
    """

    def __init__(self, nChannels, nDenselayer, growthRate):
        """
        :param nChannels: input feature 의 channel 수
        :param nDenselayer: RDB(residual dense block) 에서 Conv 의 개수
        :param growthRate: Conv 의 output layer 의 수
        """
        super(RDB, self).__init__()
        nChannels_ = nChannels
        modules = []
        for i in range(nDenselayer):
            modules.append(make_dense(nChannels, nChannels_, growthRate))
            nChannels_ += growthRate
        self.dense_layers = nn.Sequential(*modules)

        ###################kingrdb ver2##############################################
        # self.conv_1x1 = nn.Conv2d(nChannels_ + growthRate, nChannels, kernel_size=1, padding=0, bias=False)
        ###################else######################################################
        self.conv_1x1 = nn.Conv2d(nChannels_, nChannels, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        out = self.dense_layers(x)
        out = self.conv_1x1(out)
        # local residual 구조
        out = out + x
        return out

def RDB_Blocks(channels, size):
    bundle = []
    for i in range(size):
        bundle.append(RDB(channels, nDenselayer=8, growthRate=64))  # RDB(input channels,
    return nn.Sequential(*bundle)

####################################################################################################################
# Group of Residual dense block (GRDB) architecture
class GRDB(nn.Module):
    """
    https://github.com/lizhengwei1992/ResidualDenseNetwork-Pytorch
    """

    def __init__(self, numofkernels, nDenselayer, growthRate, numforrg):
        """
        :param nChannels: input feature 의 channel 수
        :param nDenselayer: RDB(residual dense block) 에서 Conv 의 개수
        :param growthRate: Conv 의 output layer 의 수
        """
        super(GRDB, self).__init__()

        modules = []
        for i in range(numforrg):
            modules.append(RDB(numofkernels, nDenselayer=nDenselayer, growthRate=growthRate))
        self.rdbs = nn.Sequential(*modules)
        self.conv_1x1 = nn.Conv2d(numofkernels * numforrg, numofkernels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        out = x
        outputlist = []
        for rdb in self.rdbs:
            output = rdb(out)
            outputlist.append(output)
            out = output
        concat = torch.cat(outputlist, 1)
        out = x + self.conv_1x1(concat)
        return out

# Group of group of Residual dense block (GRDB) architecture
class GGRDB(nn.Module):
    """
    https://github.com/lizhengwei1992/ResidualDenseNetwork-Pytorch
    """

    def __init__(self, numofmodules, numofkernels, nDenselayer, growthRate, numforrg):
        """
        :param nChannels: input feature 의 channel 수
        :param nDenselayer: RDB(residual dense block) 에서 Conv 의 개수
        :param growthRate: Conv 의 output layer 의 수
        """
        super(GGRDB, self).__init__()

        modules = []
        for i in range(numofmodules):
            modules.append(GRDB(numofkernels, nDenselayer=nDenselayer, growthRate=growthRate, numforrg=numforrg))
        self.grdbs = nn.Sequential(*modules)

    def forward(self, x):
        output = x
        for grdb in self.grdbs:
            output = grdb(output)

        return x + output

####################################################################################################################


class ResidualBlock(nn.Module):
    """
    one_to_many 논문에서 제시된 resunit 구조
    """
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        residual = self.bn1(x)
        residual = self.relu1(residual)
        residual = self.conv1(residual)
        residual = self.bn2(residual)
        residual = self.relu2(residual)
        residual = self.conv2(residual)
        return x + residual


def ResidualBlocks(channels, size):
    bundle = []
    for i in range(size):
        bundle.append(ResidualBlock(channels))
    return nn.Sequential(*bundle)

DenoisingMoels.py

from models.subNets import *
from models.cbam import *


class ntire_rdb_gd_rir_ver1(nn.Module):
    def __init__(self, input_channel, numforrg=4, numofrdb=16, numofconv=8, numoffilters=64, t=1):
        super(ntire_rdb_gd_rir_ver1, self).__init__()

        self.numforrg = numforrg  # num of rdb units in one residual group
        self.numofrdb = numofrdb  # num of all rdb units
        self.nDenselayer = numofconv
        self.numofkernels = numoffilters
        self.t = t

        self.layer1 = nn.Conv2d(input_channel, self.numofkernels, kernel_size=3, stride=1, padding=1)
        # self.layer2 = nn.ReLU()
        self.layer3 = nn.Conv2d(self.numofkernels, self.numofkernels, kernel_size=4, stride=2, padding=1)

        modules = []
        for i in range(self.numofrdb // self.numforrg):
            modules.append(GRDB(self.numofkernels, self.nDenselayer, self.numofkernels, self.numforrg))
        self.rglayer = nn.Sequential(*modules)

        self.layer7 = nn.ConvTranspose2d(self.numofkernels, self.numofkernels, kernel_size=4, stride=2, padding=1)

        # self.layer8 = nn.ReLU()
        self.layer9 = nn.Conv2d(self.numofkernels, input_channel, kernel_size=3, stride=1, padding=1)
        self.cbam = CBAM(self.numofkernels, 16)

    def forward(self, x):
        out = self.layer1(x)
        # out = self.layer2(out)
        out = self.layer3(out)

        # out = self.rglayer(out)
        for grdb in self.rglayer:
            for i in range(self.t):
                out = grdb(out)

        out = self.layer7(out)
        out = self.cbam(out)

        # out = self.layer8(out)
        out = self.layer9(out)

        # global residual 구조
        return out + x

class ntire_rdb_gd_rir_ver2(nn.Module):
    def __init__(self, input_channel, numofmodules=2, numforrg=4, numofrdb=16, numofconv=8, numoffilters=64, t=1):
        super(ntire_rdb_gd_rir_ver2, self).__init__()

        self.numofmodules = numofmodules # num of modules to make residual
        self.numforrg = numforrg  # num of rdb units in one residual group
        self.numofrdb = numofrdb  # num of all rdb units
        self.nDenselayer = numofconv
        self.numofkernels = numoffilters
        self.t = t

        self.layer1 = nn.Conv2d(input_channel, self.numofkernels, kernel_size=3, stride=1, padding=1)
        # self.layer2 = nn.ReLU()
        self.layer3 = nn.Conv2d(self.numofkernels, self.numofkernels, kernel_size=4, stride=2, padding=1)

        modules = []
        for i in range(self.numofrdb // (self.numofmodules * self.numforrg)):
            modules.append(GGRDB(self.numofmodules, self.numofkernels, self.nDenselayer, self.numofkernels, self.numforrg))
        for i in range((self.numofrdb % (self.numofmodules * self.numforrg)) // self.numforrg):
            modules.append(GRDB(self.numofkernels, self.nDenselayer, self.numofkernels, self.numforrg))
        self.rglayer = nn.Sequential(*modules)

        self.layer7 = nn.ConvTranspose2d(self.numofkernels, self.numofkernels, kernel_size=4, stride=2, padding=1)

        # self.layer8 = nn.ReLU()
        self.layer9 = nn.Conv2d(self.numofkernels, input_channel, kernel_size=3, stride=1, padding=1)
        self.cbam = CBAM(numoffilters, 16)

    def forward(self, x):
        out = self.layer1(x)
        # out = self.layer2(out)
        out = self.layer3(out)

        for grdb in self.rglayer:
            for i in range(self.t):
                out = grdb(out)

        out = self.layer7(out)
        out = self.cbam(out)

        # out = self.layer8(out)
        out = self.layer9(out)

        # global residual 구조
        return out + x



class Generator_one2many_gd_rir_old(nn.Module):
    def __init__(self, input_channel, numforrg=4, numofrdb=16, numofconv=8, numoffilters=64):
        super(Generator_one2many_gd_rir_old, self).__init__()

        self.numforrg = numforrg  # num of rdb units in one residual group
        self.numofrdb = numofrdb  # num of all rdb units
        self.nDenselayer = numofconv
        self.numofkernels = numoffilters

        self.layer1 = nn.Conv2d(input_channel, self.numofkernels, kernel_size=3, stride=1, padding=1)
        self.layer2 = nn.ReLU()
        self.layer3 = nn.Conv2d(self.numofkernels, self.numofkernels, kernel_size=4, stride=2, padding=1)

        modules = []
        for i in range(self.numofrdb // self.numforrg):
            modules.append(GRDB(self.numofkernels, self.nDenselayer, self.numofkernels, self.numforrg))
        self.rglayer = nn.Sequential(*modules)

        self.layer7 = nn.ConvTranspose2d(self.numofkernels, self.numofkernels, kernel_size=4, stride=2, padding=1)
        self.layer8 = nn.ReLU()
        self.layer9 = nn.Conv2d(self.numofkernels, input_channel, kernel_size=3, stride=1, padding=1)
        self.cbam = CBAM(self.numofkernels, 16)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)

        out = self.rglayer(out)

        out = self.layer7(out)
        out = self.cbam(out)
        out = self.layer8(out)
        out = self.layer9(out)

        # global residual 구조
        return out + x

cbma.py

import torch
import math
import torch.nn as nn
import torch.nn.functional as F

class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=False, bn=False, bias=True):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types
    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type=='avg':
                avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( avg_pool )
            elif pool_type=='max':
                max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( max_pool )
            elif pool_type=='lp':
                lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( lp_pool )
            elif pool_type=='lse':
                # LSE pool only
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp( lse_pool )

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs

class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = F.sigmoid(x_out) # broadcasting
        return x * scale

class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial=no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()
    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out

def weights_init_rcan(m):
    """
    custom weights initialization called on netG and netD
    https://github.com/pytorch/examples/blob/master/dcgan/main.py
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        if classname.find('BasicConv') != -1:
            m.conv.weight.data.normal_(0.0, 0.02)
            if m.bn != None:
                m.bn.bias.data.fill_(0)
        else:
            m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

DGU-3DMlab1_track1.py

import numpy as np
import cv2
import torch
from models.DenoisingModels import *
from utils.utils import *
from utils.transforms import *
import scipy.io as sio
import time
import tqdm

if __name__ == '__main__':

    print('********************Test code for NTIRE challenge******************')

    # path of input .mat file
    mat_dir = 'mats/BenchmarkNoisyBlocksRaw.mat'

    # Read .mat file
    mat_file = sio.loadmat(mat_dir)

    # get input numpy
    noisyblock = mat_file['BenchmarkNoisyBlocksRaw']
    
    print('input shape', noisyblock.shape)

    # path of saved pkl file of model
    modelpath = 'checkpoints/DGU-3DMlab1_track1.pkl'
    expname = 'DGU-3DMlab1_track1'

    # set gpu
    device = torch.device('cuda:0')

    # make network object
    model = Generator_one2many_gd_rir_old(input_channel=1, numforrg=4, numofrdb=16, numofconv=8, numoffilters=67).to(device)

    # make numpy of output with same shape of input
    resultNP = np.ones(noisyblock.shape)
    print('resultNP.shape', resultNP.shape)

    submitpath = f'results_folder/{expname}'
    make_dirs(submitpath)

    # load checkpoint of the model
    checkpoint = torch.load(modelpath)
    model.load_state_dict(checkpoint['state_dict'])

    transform = ToTensor()
    revtransform = ToImage()

    # pass inputs through model and get outputs
    with torch.no_grad():
        model.eval()
        starttime = time.time()     # check when model starts to process
        for imgidx in tqdm.tqdm(range(noisyblock.shape[0])):
            for patchidx in range(noisyblock.shape[1]):
                img = noisyblock[imgidx][patchidx]   # img shape (256, 256, 3)

                input = transform(img).float()
                input = input.view(1, -1, input.shape[1], input.shape[2]).to(device)

                output = model(input)       # pass input through model

                outimg = revtransform(output)   # transform output tensor to numpy

                # put output patch into result numpy
                resultNP[imgidx][patchidx] = outimg

    # check time after finishing task for all input patches
    endtime = time.time()
    elapsedTime = endtime - starttime   # calculate elapsed time
    print('ended', elapsedTime)
    num_of_pixels = noisyblock.shape[0] * noisyblock.shape[1] * noisyblock.shape[2] * noisyblock.shape[3]
    print('number of pixels', num_of_pixels)
    runtime_per_mega_pixels = (num_of_pixels / 1000000) / elapsedTime
    print('Runtime per mega pixel', runtime_per_mega_pixels)

    # save result numpy as .mat file
    sio.savemat(f'{submitpath}/{expname}', dict([('results', resultNP)]))


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM