github代碼地址:https://github.com/mrlibw/ControlGAN
關鍵詞:T2I,文本生成圖像,ControlGAN
Introduction:
現在的許多模型如果改變了輸入文本的其中一個部分,那么輸出的圖片會與原來文本生成的圖片大相徑庭,沒法實現一部分的修改。如下圖所示。

controlGAN,由三個部分組成:
1.word-level spatial and channel-wise attention-drive generator,采用了attention機制,多層次結構。
2.word-level discriminator,研究詞與圖像子區域的關系,來區分不同視覺屬性。
3.perceptual loss,通過減少生成過程中的隨機性,強制generator保持與修改文本無關的部分。
Controllable Generative Adversarial Network
給定一段文本S,目標是合成一張與S語義相關的圖像I',同時使生成過程可控,即當S修改為Sm時,合成結果I''應與Sm語義相關,同時保留與修改的文本無關的內容。
模型結構。
ControlGAN基於multi-stage AttnGAN。
給定文本S,輸入到text encoder(一個預訓練的雙向RNN),得到文本特征s∈RD,和w屬於RDxL,s有D維,words數量L。
對s做conditioning augumentation(CA),得到增強后的文本特征s'。生成一個隨機變量z,s'和z連接到一起作為輸入送到stage I。
整個模型逐階段生成從粗糙到精細的圖片,對於每個階段,網絡輸出一個隱藏的可視特征vi,vi是相對應的generator Gi的輸入。
spatial attention和channel-wise attention會將w和vi作為輸入,輸出attentive word-context feature。這個特征會vi contact一起,作為下一階段的輸入。
spatial attention只將word與單個空間位置關聯起來,不考慮channel信息。
論文新提出的channel-wise attention考慮word與channel的關聯。
實驗發現channel-wise attention與對應詞的語義信息關聯,而spatial attention與顏色相關,因此該結構可以用來區分不同的視覺屬性。
Channel-Wise Attention,結構如圖所示。

在第k層,輸入word特征w∈RDxL和視覺特征vk∈RCx(Hk*Wk),Hk和Wk分別代表第k層特征圖的高和寬。
w通過一個perception layer Fk被映射到與vk相同的語義空間,即w'k=Fkw,Fk∈R(Hk*Wk)xD.
記channel-wise attention矩陣為mk∈RCxL,mk=vk * w'k,從而mk聚集了所有空間位置的channel和words的聯系信息。接着,使用softmax函數對mk進行歸一化,得到αk,如下圖。

attention weight αki,j 代表vk的第i個channel和文本S中的第j個單詞之間的關系,越大代表關聯越近。
最后,fαk=αk*(w'k)T, 其中 fαk∈RCx(Hk*Wk)。
fαk中蘊含了每個channel和word的關系,因此具有更高關聯值的channel在生成過程中會被增強,從而將生成過程中的每個channel給分開,並且降低無關的channel帶來的影響。
Word-level Discriminator,如圖所示。

為了讓generator只修改部分圖像內容,discriminator應向generator提供詳細的訓練數據。
輸入word特征w和w',w和w'∈RDxL,其中,w根據原始文本S編碼得到,w'是從一個隨機采樣的不匹配文本中編碼得到。視覺特征nreal和nfake,由基於GoogleNet的圖片encoder得到,它們分別有real image I和生成的image I'得到。
為了簡單起見,使用n∈RCx(H*w)來代表nreal和nfake。使用w屬於RDxL來代表兩個文本特征w和w'。
word-level discriminator包含一個perception layer F',它用於對准n和w的channel維度,即得到 n'=F' * n,其中F'∈RDxC,是一個待學習的權重矩陣。
接着,計算word-context關聯矩陣m=wT * n',其中m∈RLx(H*W)。然后使用softmax函數進行歸一化得到關聯矩陣β。

其中, βi,j代表第i個word和第j個圖像子區域只見的關聯值。 然后計算感知圖像子區域的word特征b,b=n' * βT,b∈RDxL。b包含了所有空間信息。
此外,通過一個word-level的self-attention得到一維向量γ,長度L代表每個單詞的相對重要性。重復γ D次得到γ',γ'屬於∈RDxL。
計算b'=b⊙γ',⊙代表element-wise的乘積,即b'i,j為bi,j*γ'i,j.
最后根據如下公式得到第i個單詞和整副圖片的關聯。

σ是sigmoid函數。
最后,計算Image和Sentence的最終關聯Lcorre,Lcorre=Σi=0L-1ri.將其反饋給generator就可以進一步幫助修改每一個子區域。
Perceptual Loss
由於沒有在於文本無關的圖像區域施加限制,生成的圖片可能有高度隨機性,也可能會和其他內容語義不相關。為了減少隨機性,本論文引入了基於16-layer VGG network的perceptual loss,該模型在ImageNet數據集上預訓練過。該網絡結構用於從生成的圖片I'和真實的圖片I中提取語義特征,定義如下:

其中Φi(I)代表VGG的第i層的activation。
目標函數
generator和discriminator是通過交替訓練來降低generator loss LG和discriminator loss LD
generator loss LG包括對抗損失LGk,感知損失Lper,文本和圖片關聯度損失Lcorre和基於余弦相似度的文本圖片匹配損失LDAMSM。

K是stage數,Ik是從真實數據分布Pdata中采樣得到的圖片的第k個stage。I'k是從模型分布PGk中采樣得到的。
三個λ是超參數。
LGk由非條件對抗損失和條件對抗損失構成,其中非條件對抗損失用於確保圖片的真實度,條件對抗損失保證文本和圖片匹配。

discriminator loss.

其中,Lcorre代表與單詞相關的區域是否存在,S'是從文本數據集中隨機采樣的與Ik不匹配的句子。
對抗損失LDk與generator相同。

實驗
數據集
基於CUB和COCO。CUB包含8855訓練圖片和2933測試圖片,每張圖片有10個對應文本。COCO包含82783訓練圖片和40504驗證圖片,每張圖片有5個對應文本。使用StackGAN中介紹的方法對其進行數據預處理。
實現
其他:
Image Caption,即 Image-To-Text。
