StarGAN的引入是為了解決多領域間的轉換問題的,之前的CycleGAN等只能解決兩個領域之間的轉換,那么對於含有C個領域轉換而言,需要學習C*(C-1)個模型,但StarGAN僅需要學習一個,而且效果很棒,如下:

創新點:為了實現可轉換到多個領域,StarGAN加入了一個域的控制信息,類似於CGAN的形式。在網絡結構設計上,鑒別器不僅僅需要學習鑒別樣本是否真實,還需要對真實圖片判斷來自哪個域。

整個網絡的處理流程如下:
- 將輸入圖片x和目標生成域c結合喂入到生成網絡G來合成fake圖片
- 將fake圖片和真實圖片分別喂入到鑒別器D,D需要判斷圖片是否真實,還需要判斷它來自哪個域
- 與CycleGAN類似,還有一個一致性約束,將生成的fake圖片和原始圖片的域信息c'結合起來喂入到生成器G要求能輸出重建出原始輸入圖片x
下面分析一下各個部分的損失函數:
一:GAN常見的對抗損失:

二:對於給定的輸入圖片x和目標域標簽c,網絡的目標是將x轉換成輸出圖片y,輸出圖片y能夠被歸類成目標域c。為了實現這一點就需要鑒別器有判別域的功能。所以作者在D的頂端加了一個額外的域分類器,域分類器loss在優化D和G時都會用到,作者將這一損失分為兩個方向,分別用來優化G和D。(這很容易理解,因為如下分析可以看到公式(3)沒有辦法為D提供訓練需要的監督信息)
一個是真實圖片的域分類損失用來優化D,另一個是fake圖片的域分類損失來優化G。
1)![]()
Dcls(c'|x)代表D對真實圖片計算得到的域標簽概率分布。這一學習目標將會使得D能夠將輸入圖片x識別為對應的域c',這里的(x,c')是訓練集給定的。
2)
fake圖片的域分類的損失函數定義如(3),它用來優化G,也就是讓G盡力去生成圖片讓它能夠被D分類成目標域c。
三:還有一個重建損失
通過最小化對抗損失與分類損失,G努力嘗試做到生成目標域中的現實圖片。但是這無法保證學習到的轉換只會改變輸入圖片的域相關的信息而不改變圖片內容。所以加上了周期一致性損失:
這里就是將G(x,c)和圖片x的原始標簽c'結合喂入到G中,將生成的圖片和x計算1范數差異。
總體損失:
在實際操作上,作者將對抗損失換成了WGAN的對抗損失:

以上對於單個數據集的訓練來說已經足夠了,但是現在想想另一個問題,假如我要聯合訓練多個數據集呢?
舉例來說,celebA和RaFD數據集,前者有發色和性別信息,后者有面部表情信息,我能將celebA中的人物改變一下面部表情嗎?
一個很簡單的想法是如果我原來的域標注信息是5位的onehot編碼,現在變長為8位不就可以了。但是這存在一個問題就是celebA中的人其實也有表情,只是沒有標注,RaFD其實也有性別區別,但對於網絡來說沒標記就是未知的。簡單擴充域標記信息位是肯定不行的。我們希望網絡只關注它有明確信息的那一部分標注。
因此,作者加了一個mask。在聯合多個數據集訓練時把mask向量也輸入到生成器。
以上的ci代表第i個數據集的標簽,已知標簽ci如果是二進制屬性則可以表示為二進制向量,如果為類別屬性表示一個onehot。剩下的n-1個則指定為0。m則是一個長度為n的onehot編碼。這樣網絡就會只關注已給定的標簽。
論文部分到此結束,下面來分析一下代碼:
主要的代碼有model.py和solver.py兩個。
在model.py中作者創建了生成器G與鑒別器D。
在生成器中先對模型降維縮小為原來4倍,再使用多個殘差網絡獲得等維度輸出,接着使用轉置卷積放大4倍,最后通過一層尺寸不變的卷積,取tanh作為輸出。

另外一個值得注意的是生成器如何將輸入圖片與目標域c一起結合作為輸入的,代碼中可以看出就是直接在第四維度上進行拼接(pytorch一般為N*C*H*W,所以看起來是在第二維)。
對於鑒別器,使用conv1的輸出代表域的預測概率,conv2的輸出代表圖片是否為真的判斷。這兩個的關系是並行的。

Solver.py比較長,挑選重要的部分來解釋:
首先是梯度懲罰,這一部分來自WGAN的改善工作,主要是為了滿足Lipschitz連續這個WGAN推導中需要的數學約束。

令人疑惑的是分類loss並不都是交叉熵損失,這是因為CelebA的標簽是多屬性的,不是一個onehot,所以使用了一個多個二分類的形式,而RaFD則是一個onehot。
下面來看看在多個訓練集訓練時代碼上是怎么操作的。
在數據加載上其實還是單個數據集輪流進行操作的,如下:

以上提到在多數據集訓練時,我們需要mask向量,mask向量的形成按如下形式進行拼接,前面是celebA的label后面是RaFD的label,最后是onehot,代表了哪個數據集的標簽是已知的。

以生成器為例,計算損失時也是只在輸出判斷向量中提取該數據集已知的部分進行loss計算。

