“半監督”異常檢測方法GANomaly


原文標題:GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training

原文鏈接:

背景介紹

異常檢測是計算機視覺領域一個比較經典的問題,它旨在區分正常樣本(下文稱為OK樣本)和非正常樣本(下文稱為NG樣本)。乍一看,像是普通的二分類問題。其實不然,異常檢測有一個內在的屬性:樣本極其不平衡,即OK樣本非常多,NG樣本非常少。極端情況,訓練階段見不到任何NG樣本,該問題就變成了單分類問題了(本文也將這種只有OK樣本而沒有NG樣本參與訓練的情況稱為“半監督”,筆者認為是不妥的)。本文提出的GANomaly方法,就是針對這種極端情況的。

由於異常檢測問題中NG樣本通常比較少,直接學習能區分NG樣本的模型是很困難的。既然NG樣本不可靠,那大家自然會想到采取相反的思路,學習能區分OK樣本的模型就好,只要跟OK長得不像的就認為是NG的。自編碼器(Autoencoder)是異常檢測中比較經典的一種方法。它的解決思路是采用盡可能多的OK樣本去學習一個自編碼模型,由於該模型見過足夠多的OK樣本,因此它能夠很好地將OK樣本重建出來,而NG樣本它是沒見過的,因此它沒法很好地重建出來。推理階段,通過輸入圖片的重建誤差,就可以區分出OK和NG樣本了。但是,該方法非常容易受噪聲影響,需要在自編碼器上加各種約束,才能得到一個可用的異常檢測模型。

主要思想

 

 

如上圖所示,不同於一般的基於自編碼器的方法,本文采用的是一個編碼器(Encoder1)-解碼器(Decoder)-編碼器(Encoder2)的網絡結構,同時學習“原圖->重建圖”和“原圖的編碼->重建圖的編碼”兩個映射關系。該方法不僅對生成的圖片外觀(圖片->圖片)做了的約束,也對圖片內容(圖片編碼->圖片編碼)做了約束。另外,該方法還引入了生成對抗網絡(GAN)中的對抗訓練思想。這里,作者將Encoder1-Decoder-Encoder2當成生成網絡G-Net,又定義了一個判別網絡D-Net,通過交替訓練生成網絡和對抗網絡,最終學到一個比較好的生成網絡。

推理階段,該方法也不同於一般的基於自編碼器的異常檢測方法。最后用於推斷異常的不是原圖和重建圖的差異,而是第一部分編碼器產生的隱空間特征(原圖的編碼)和第二部分編碼器產生的隱空間特征(重建圖的編碼)的差異。這種方法更關注圖片實質內容的差異,對圖片中的微小變化不敏感,因而能解決自編碼器中易受噪聲影響的問題,魯棒性更好。

筆者認為本文的主要貢獻在於提出了這個Encoder1-Decoder-Encoder2的結構,D-Net只是錦上添花的。因為即便沒有D-Net和對抗訓練的思想,好好調參數該方法也可以work。

網絡結構

本文網絡結構包含三個子網絡。

第一個子網絡是一個常規的碗形的自編碼器,它的作用是用於重建輸入的OK圖像。該自編碼器結構的設計參考了DCGAN,具體而言,該自編碼器的解碼器部分(Decoder)和DCGAN的生成網絡幾乎是一樣的,即從一個n維的向量(bottleneck1)映射到一張3通道的圖片,如下圖所示。該自編碼器的編碼器部分(Encoder1)則是編碼器的逆過程,即從一張3通道的圖片映射到一個n維的向量。

第二個子網絡是一個編碼網絡(Encoder2),它的作用是將第一個子網絡重建出來的圖片再壓縮為一個n維的向量(bottleneck2)。雖然Encoder2采用的結構和Encoder1是一樣的,但它們的參數顯然是不一樣的。這么一個重復的結構看起來沒有什么了不起的,但筆者認為該結構是本文思想中最為核心的地方,它摒棄了絕大部分基於自編碼器的異常檢測方法常用的通過對比原圖和重建圖的差異來推斷異常的方式,采用了一種新的通過對比原圖和重建圖在高一層抽象空間中的差異來推斷異常的方式,而這一層額外的抽象可以使其大大提高抗噪聲干擾的能力,學到更加魯棒的異常檢測模型。

文章中第一個子網絡和第二個子網絡共同構成了生成對抗網絡中的生成網絡(G-Net),聽起來有點費解。其實可以換個思路想,第一個子網絡就是一個中規中矩的生成網絡,第二個子網絡只是它的一個約束條件而已。

第三個子網絡是一個判別網絡(D-Net),它的作用是用於區分原圖和重建圖(G-Net生成的圖片),即要將原圖判別為真,將重建圖判別為假。它的結構和第一個子網絡的解碼網絡是一樣的。D-Net的引入,是為了引入對抗訓練思想,旨在學到更好的G-Net。

綜上,該文章設計的網絡結構事實上比較簡單,就是一個Encoder和一個Decoder,只是通過不同的組合,生成了三部分的子網絡。接下來將介紹每部分子網絡采用的損失函數。

損失函數

本文包含三個子網絡,每個子網絡對應一個損失函數。由於文章中寫的損失函數和作者公布的代碼中的損失函數有些出入,筆者認為代碼中的損失函數更為合理,因此下文介紹的都是代碼中的損失函數。

第一個子網絡的損失是自編碼器的重建損失,這里借鑒了pix2pix文章中生成網絡的損失,采用的是L1損失,而不是L2損失。因為采用L2損失生成的圖像通常比采用L1生成的圖像要模糊。

[公式]

第二個子網絡的損失是編碼網絡的損失,這里需要比對的是原圖和重建圖在高一層抽象空間中的差異,即兩個bottleneck(上文中的bottleneck1和bottleneck2)間的差異,采用的是L2損失。

[公式]

第三個子網絡的損失是常規的GAN中判別網絡的損失,這里采用的是二分類的交叉熵損失。

[公式]

正常來說,采用第一個子網絡的生成損失和第三個子網絡的判別損失就能生成比較不錯的圖片了,但是這篇文章主要解決的是異常檢測問題。異常是圖片集的特性,采用像素級的損失(原圖和重建圖的差異)來推斷是不夠合理的,因而引入了第二個子網絡的編碼損失,文章中最后用於推斷的也是該損失。

訓練

本文采用的訓練策略和常規的GAN一樣的,即交替地優化D-Net和G-Net。

優化D-Net時,采用的損失為上述第三個子網絡的損失,即:

[公式]

這里的輸入 [公式] 。雖然這里的 [公式] 需要通過G-Net來生成,但是訓練D-Net時,G-Net的參數是固定的。

優化G-Net時,采用的損失比較復雜:

[公式]

主體損失為重建損失 [公式] ,編碼損失[公式]為重建損失的一個約束,對抗損失[公式]則是用來和D-Net博弈。需要注意的一點是,這里的對抗損失的輸入對象和優化D-Net時的輸入對象是不一樣的,這里的 [公式] ,這和常規GAN的訓練是一致的。

推斷

前面提到,本文采用的推斷方式和一般的基於自編碼器的異常檢測方法是不一樣的。這里推斷以來的不是重建損失[公式],而是編碼損失[公式]。具體而言,網絡訓練收斂以后,我們可以計算得到所有OK樣本中的[公式]值,選取其中最大的作為判別閾值。推斷時,給定一張圖片,我們可以利用學好的網絡,計算其 [公式] 值,如果它小於判別閾值則判斷為OK樣本(正常樣本),大於則判斷為NG樣本(異常樣本)。

實驗

要做基於GANomaly的異常檢測實驗,需要准備大量的OK樣本和少量的NG樣本。找不到合適的數據集怎么辦?很簡單,隨便找個開源的分類數據集,將其中一個類別的樣本當作異常類別,其他所有類別的樣本當作正常樣本即可,文章中的實驗就是這么干的。具體試驗結果如下:

反正在效果上,GANomaly是超過了之前兩種代表性的方法。此外,作者還做了性能對比的實驗。事實上前面已經介紹了GANomaly的推斷方法,就是一個簡單的前向傳播和一個對比閾值的過程,因此速度非常快。具體結果如下:

可以看出,計算性能上,GANomaly表現也是非常不錯的。

總結

雖然異常檢測在數據挖掘領域很早就有人做了,但是計算機視覺領域的相關研究還相對較少。另外,GAN這幾年非常火,GAN到底能不能做異常檢測,還沒有太多人嘗試過。本文算是一個比較成功地將GAN用到異常檢測的例子。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM