條件GAN(Conditional Generative Adversarial Nets),原文地址為CGAN。
Abstract
生成對抗網絡(GAN)是最近提出的訓練生成模型(generative model)的新方法。在本文中,我們介紹了條件GAN(下文統一簡稱為CGAN),簡單來說我們把希望作為條件的data y同時送入generator和discriminator。我們在文中展示了在數字類別作為條件的情況下,CGAN可以生成指定的MNIST手寫數字。我們同樣展示了CGAN可以用來學習多形態模型(multi-modal model),我們提供了一個image-tagging的應用,其中我們展示了這種模型可以產生豐富的tags,這些tags並不是訓練標簽的一部分。
1. Introduction
GAN最近被提出作為訓練生成模型的替代框架,用來規避有些情況下近似復雜概率的計算的困難。 GAN的一個重要優勢就是不需要計算馬爾科夫鏈(Markov chains),只需要通過反向傳播算法計算梯度,在學習過程中不需要進行推斷(inference),一系列的factors和interactions可以被輕易地加入到model當中。 更進一步地,就像[8]中顯示的那樣,CGAN可以產生state-of-the-art的對數似然估計(log-likelihood)和十分逼真的樣本。 在非條件的生成模型中,我們沒法控制生成什么樣模式的樣本。然而,通過給model增加額外的信息,我們可以引導模型生成數據的方向。這樣的條件可以建立在類別標簽,或者[5]展示的圖像修復的部分數據,甚至可以是不同模式的數據上。 本文展示了應該如何構建CGAN。我們展示了CGAN在兩個數據集上的結果,一個是以類別標簽作為條件的MNIST數據集;還有一個是建立在MIR Flickr 25,000 dataset上的多模態學習(multi-modal learning)。
2. Related Work
2.1 圖像標簽的多模態學習
盡管最近監督神經網絡(特別是卷積網絡)取得了巨大的成功,但是將這種模型擴展到有非常大的預測輸出類別的問題上仍然面臨着巨大的挑戰。第二個問題是,當今的大部分工作都主要集中在學習輸入到輸出的一對一的映射。然而很多有趣的問題可以考慮為概率上的一對多的映射。比如說在圖片標注問題上,對於一個給定的圖片可能對應了多個標簽,不同的人類標注者可能會使用不同的(但通常是相似的或者是相關的)詞匯來描述相同的一幅圖片。 解決第一個問題的一種方式是從其他的模式中施加額外的信息,比如說通過語言模型來學習詞匯的向量形式的表達,其中幾何上的關系對應了語義上的相關。在這樣的空間(映射之后的向量空間)做預測時,一個很好的性質時,即使我們的預測錯誤了,但是仍然和真實的答案很接近(比如說預測是"table"而不是"chair"),還有一個優勢是,我們可以自然地對即使在訓練時沒有見過的詞匯做generalizations prediction,因為相似的向量語義上也是相似的。[3]的工作顯示即使是一個從圖像特征空間到單詞表達空間(word vector)的線性映射都可以提高分類的性能。 解決第二個問題的一種解決辦法是使用條件概率生成模型,輸入是作為條件變量,一對多的映射被實例化為一個條件預測分布。 [16]對第二個問題采用了和我們類似的辦法,他們在MIR Flickr 25,000 dataset上訓練了一個深度玻茲曼機。 除此之外,[12]的作者展示了如何訓練一個有監督的多模態自然語言模型,這樣可以為圖片生成描述的句子。
3. Conditional Adersarial Nets(條件對抗網絡)
3.1 Generative Adervasarial Nets
GAN是最近提出的一種新穎的訓練生成模型的方式。它包含了兩個“對抗”模型:生成模型G捕獲數據分布,判別模型D估計樣本來自訓練數據而不是G的概率。G和D都可以是非線性的映射函數,比如多層感知機模型。 為了學習生成器關於data x的分布$p_g$,生成器構建了一個從先驗噪聲分布$p_z(z)$到數據空間的映射$G(z;\theta_g)$。判別器$D(x;\theta_d)$輸出了一個單一的標量,代表x來自訓練樣本而不是$p_g$的概率。 G和D是同時訓練的:我們調整G的參數來最小化$log(1-D(g(Z)))$,然后調整D的參數來最小化$log(D(X))$,他們就像如下的兩人的最小最大化博弈(two player min-max game),價值函數(value function)為$V(G,D)$:
3.2 Conditional Adersarial Nets
如果生成器和判別器都基於一些額外的信息y的話,GAN可以擴展為一個條件模型。y可以是任何形式的輔助信息,比如說類別標簽或者其他模式的數據。我們可以通過增加額外的輸入層來將y同時輸入生成器和判別器,來實施條件模型。 在生成器中,先驗的噪聲輸入$p_z(z)$和y被結合成一個連接隱藏表達(joint hidden representation),對抗訓練的框架為組成隱藏表達(compose of hidden representation)提供了相當大的靈活性。 在判別器中,x和y被作為輸入送入判別函數(再一次地,比如可以是一個MLP,多層感知器)。 Two player minimax game的目標函數如公式(2):
圖1展示了一個簡單的條件對抗網絡的架構。
圖1 條件對抗網絡
4. Experimental Results
4.1 Unimodal(單一模式)
我們以類別標簽作為條件在MNIST數據集上訓練了一個對抗網絡,類別標簽是作為one-hot vectors的形式。 在生成網絡中,100維的噪聲先驗分布是從unit hypercube(單位超方體)的均勻分布采樣得到的。z和y都是映射到帶有relu激活函數的hidden layers,隱藏層節點數分別為200和1000,然后二者的輸出相結合形成一個節點數為1200的帶有relu激活函數的hidden layer,最后是一個sigmoid unit hidden layer作為輸出,生成784維的MNIST samples。
判別器將x映射到一個有240 units and 5 pieces的maxout layer[6],y映射到一個有50 units and 5 pieces 的maxout layer。這兩個hidden layers在被送入sigmoid layer之前都被映射到一個有240 units and 4 pieces 的joint maxout layer。(判別器的准確的架構不是特別重要,只要有sufficient power即可;我們發現對於這個任務maxout units非常合適)。 模型的訓練使用SGD,mini-batch size 為100,初始化的學習率為0.1,指數衰減因子為1.00004,最終的學習率為0.000001。momentum參數初始化為0.5,最終增加到0.7。generator和discriminator都需要使用dropout,dropout rate為0.5。在validation set 上的最佳對數似然估計被作為停止點(early stop)。 表1顯示了對於MNIST的test data的Gaussian Parzen window對數似然估計。從10個類別的每一個類別采樣共得到1000個samples,然后使用Gaussian Parzen window來擬合這些samples。然后我們使用Parzen window 分布來估計測試集的對數似然。([8]詳細介紹了怎么做這種估計)。 條件對抗網絡的結果顯示了,我們的實驗結果和基於其他網絡得到的結果相近,但是比其中的幾種方法更加優越——包括非條件對抗網絡。我們展示這種優越性更多是基於概念上的,而不是具體的功效,我們相信,未來如果對超參數和模型架構進行更深入的探索,條件模型可以達到甚至超過非條件模型的結果。 圖2顯示了一些生成的樣本,每一行是基於一個label生成的樣本,而每一列則代表了生成的不同樣本。
圖2 生成的MNIST手寫數字,每一行是以一個label作為條件
4.2 Multimodal(多模態)
像Flickr這樣的圖像網站,是圖像以及用戶為圖像生成的額外信息(user-generated metadata,UGM)的有標記數據的豐富來源——特別數用戶提供的標簽。 用戶提供的標記信息與經典的圖像標簽不一樣的地方在於,用戶提供的標記信息內容更加豐富,語義上也更加接近人類用自然語言對於圖像的描述,而不僅僅是識別出圖像中有什么東西。UGM中同義詞很普遍,不同的用戶對相同的圖像內容可能用不同的詞匯去描述,因此,找到一種對這些標簽進行標准化的有效方式是非常重要的。概念上的詞向量是非常有用的,因為表達成為詞向量之后,語義相近的詞向量在距離上也是相近的。在本節當中,我們展示了圖像的自動標記,可以帶有多個預測標簽,我們基於圖像特征使用條件對抗網絡生成(可能是多模態)標簽向量的分布。 對於圖像特征來說,我們采用和[13]類似的方法,在帶有21000個標簽的全部ImageNet數據集上預訓練了一個卷積網絡。我們使用了卷積網絡最后一層帶有4096個units的全連接層作為圖像的特征表達。 對於單詞表達來說,我們從[YFCC100M](http://webscope.sandbox.yahoo.com/catalog)數據集獲取了用戶標簽,標題以及圖像描述的語料庫。在對文本進行預處理以及清洗之后,我們訓練了一個skip-gram model,word vector的size是200。我們從詞典當中丟棄了出現次數少於200次的詞匯。最后詞典的大小是247465。 我們在訓練對抗網絡過程中保持卷積網絡和語言模型(language model)固定。未來我們將會探索,將反向傳播同時應用於對抗網絡,卷積網絡和語言模型。 在實驗的過程中,我們使用了MIR Flickr 25,000 數據集,並且使用了如上所述的卷積網絡和語言模型提取了圖像特征和標簽(詞向量)特征。沒有任何標簽的圖像被我們舍棄了,注釋被看做是額外的標簽。前15萬的樣例被作為訓練樣本。有多個標簽的圖像,帶有每一個標簽的圖像分別被看做一組數據。 評估過程,對於每個圖像我們生成了100個samples,並且使用余弦距離找出了最相近的20個單詞。然后我們選取了100個samples中最常出現的10個單詞。表4.2展示了一些用戶關聯生成的標簽和注釋以及生成的標簽。 表現最佳的條件對抗網絡的生成器接收size為100的高斯噪聲作為先驗噪聲,然后將它映射到500維的relu層,然后將4096層的圖像特征映射到2000維的relu layer,這些層都被映射到一個200維的線性layer的然后連接表達,最后輸出生成的詞向量。 鑒別器由對於詞向量500維的relu layer,對圖像特征1200維的relu layer組成,然后是一個帶有1000個units和3pieces的maxout layer,最后送入sigmoid單元得到輸出。 模型的訓練使用了隨機梯度下降(SGD),batch size =100,初始學習率為0.1,指數衰減率為1.00004,最后學習率下降到0.000001。同時模型也使用了momentum(動量加速),初始值為0.5,最后上升到0.7。生成器和鑒別器都使用了dropout,dropout rate 為0.5。 超參數以及模型架構由交叉驗證還有混合了手工以及的網格搜索的方法所得到。
5. Feature Work
本文顯示的結果非常初步,但是它展示了條件對抗網絡的潛力,同時也為有趣且有用的應用提供了新的思路。在未來進一步的探索當中,我們希望展示更加豐富的模型以及對於模型表現、特性更加具體深入的分析。同時在當前的實驗中,我們僅僅使用了每個單獨的標簽,我們希望可以通過一次使用多個標簽取得更好的結果。 另外一個顯然未來可以探索的方向是我們可以將對抗網絡和語言模型結合到一起訓練。[12]的工作顯示了我們可以學習到針對特定任務的語言模型。
References
[1] Bengio, Y., Mesnil, G., Dauphin, Y., and Rifai, S. (2013). Better mixing via deep representations. In ICML’2013. [2] Bengio, Y., Thibodeau-Laufer, E., Alain, G., and Yosinski, J. (2014). Deep generative stochastic networks trainable by backprop. In Proceedings of the 30th International Conference on Machine Learning (ICML’14). [3] Frome, A., Corrado, G. S., Shlens, J., Bengio, S., Dean, J., Mikolov, T., et al. (2013). Devise: A deep visual-semantic embedding model. In Advances in Neural Information Processing Systems, pages 2121–2129. [4] Glorot, X., Bordes, A., and Bengio, Y. (2011). Deep sparse rectifier neural networks. In International Conference on Artificial Intelligence and Statistics, pages 315–323. [5] Goodfellow, I., Mirza, M., Courville, A., and Bengio, Y. (2013a). Multi-prediction deep boltzmann machines. In Advances in Neural Information Processing Systems, pages 548–556. [6] Goodfellow, I. J., Warde-Farley, D., Mirza, M., Courville, A., and Bengio, Y. (2013b). Maxout networks. In ICML’2013. [7] Goodfellow, I. J., Warde-Farley, D., Lamblin, P., Dumoulin, V., Mirza, M., Pascanu, R., Bergstra, J., Bastien, F., and Bengio, Y. (2013c). Pylearn2: a machine learning research library. arXiv preprint arXiv:1308.4214. [8] Goodfellow, I. J., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., and Bengio, Y. (2014). Generative adversarial nets. In NIPS’2014. [9] Hinton, G. E., Srivastava, N., Krizhevsky, A., Sutskever, I., and Salakhutdinov, R. (2012). Improving neural networks by preventing co-adaptation of feature detectors. Technical report, arXiv:1207.0580. [10] Huiskes, M. J. and Lew, M. S. (2008). The mir flickr retrieval evaluation. In MIR ’08: Proceedings of the 2008 ACM International Conference on Multimedia Information Retrieval, New York, NY, USA. ACM. [11] Jarrett, K., Kavukcuoglu, K., Ranzato, M., and LeCun, Y. (2009). What is the best multi-stage architecture for object recognition? In ICCV’09. [12] Kiros, R., Zemel, R., and Salakhutdinov, R. (2013). Multimodal neural language models. In Proc. NIPS Deep Learning Workshop. [13] Krizhevsky, A., Sutskever, I., and Hinton, G. (2012). ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems 25 (NIPS’2012). [14] Mikolov, T., Chen, K., Corrado, G., and Dean, J. (2013). Efficient estimation of word representations in vector space. In International Conference on Learning Representations: Workshops Track. [15] Russakovsky, O. and Fei-Fei, L. (2010). Attribute learning in large-scale datasets. In European Conference of Computer Vision (ECCV), International Workshop on Parts and Attributes, Crete, Greece. [16] Srivastava, N. and Salakhutdinov, R. (2012). Multimodal learning with deep boltzmann machines. In NIPS’2012. [17] Szegedy, C., Liu, W., Jia, Y., Sermanet, P., Reed, S., Anguelov, D., Erhan, D., Vanhoucke, V., and Rabinovich, A. (2014). Going deeper with convolutions. arXiv preprint arXiv:1409.4842.