【Learning Notes】變分自編碼器(Variational Auto-Encoder,VAE)


轉載自http://blog.csdn.net/jackytintin/article/details/53641885

近年,隨着有監督學習的低枝果實被采摘的所剩無幾,無監督學習成為了研究熱點。VAE(Variational Auto-Encoder,變分自編碼器)[1,2] 和 GAN(Generative Adversarial Networks) 等模型,受到越來越多的關注。

筆者最近也在學習 VAE 的知識(從深度學習角度)。首先,作為工程師,我想要正確的實現 VAE 算法,以及了解 VAE 能夠幫助我們解決什么實際問題;作為人工智能從業者,我同時希望在一定程度上了解背后的原理。

作為學習筆記,本文按照由簡到繁的順序,首先介紹 VAE 的具體算法實現;然后,再從直觀上解釋 VAE 的原理;最后,對 VAE 的數學原理進行回顧。我們會在適當的地方,對變分、自編碼、無監督、生成模型等概念進行介紹。

我們會看到,同許多機器算法一樣,VAE 背后的數學比較復雜,然而,工程實現上卻非常簡單。

這篇 Conditional Variational Autoencoders 也是 by intuition 地介紹 VAE,幾張圖也非常用助於理解。

1. 算法實現

這里介紹 VAE 的一個比較簡單的實現,盡量與文章[1] Section 3 的實驗設置保持一致。完整代碼可以參見 repo

1.1 輸入:

數據集 XRn

做為例子,可以設想 X 為 MNIST 數據集。因此,我們有六萬張 0~9 的手寫體 的灰度圖(訓練集), 大小為 28×28。進一步,將每個像素歸一化到[0,1],則 X[0,1]784 。

MNIST 
圖1. MNIST demo (圖片來源)

1.2 輸出:

一個輸入為 m 維,輸出為 n 維的神經網絡,不妨稱之為 decoder [1](或稱 generative model [2])(圖2)。

decoder 
圖 2. decoder

  • 在輸入輸出維度滿足要求的前提下,decoder 以為任何結構——MLP、CNN,RNN 或其他。
  • 由於我們已經將輸入數據規一化到 [0, 1] 區間,因此,我們令 decoder 的輸出也在這個范圍內。這可以通過在 decoder 的最后一層加上 sigmoid 激活實現 : 
    f(x)=11+ex
  • 作為例子,我們取 m = 100,decoder 的為最普遍的全連接網絡(MLP)。基於 Keras Functional API 的定義如下:
n, m = 784, 2 hidden_dim = 256 batch_size = 100 ## Encoder z = Input(batch_shape=(batch_size, m)) h_decoded = Dense(hidden_dim, activation='tanh')(z) x_hat = Dense(n, activation='sigmoid')(h_decoded)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

1.3 訓練

VAE overview
圖 3. VAE 結構框架

1.3.1 encoder

為了訓練 decoder,我們需要一個輔助的 encoder 網絡(又稱 recognition model)(如圖3)。encoder 的輸入為 n 維,輸出為 2×m 維。同 decoder 一樣,encoder 可以為任意結構。

encoder 
圖 4. encoder

1.3.2 采樣(sampling)

我們將 encoder 的輸出(2×m 個數)視作分別為 m 個高斯分布的均值(z_mean)和方差的對數(z_log_var)。

接着上面的例子,encoder 的定義如下:

## Encoder x = Input(batch_shape=(batch_size, n)) h_encoded = Dense(hidden_dim, activation='tanh')(x) z_mean = Dense(m)(h_encoded) # 均值 z_log_var = Dense(m)(h_encoded) # 方差對數
  • 1
  • 2
  • 3
  • 4
  • 5

然后,根據 encoder 輸出的均值與方差,生成服從相應高斯分布的隨機數:

epsilon = K.random_normal(shape=(batch_size, m), 
                          mean=0.,std=epsilon_std) # 標准高斯分布 z = z_mean + exp(z_log_var / 2) * epsilon
  • 1
  • 2
  • 3

z 就可以作為上面定義的 decoder 的輸入,進而產生 n 維的輸出 x^

sampler
圖5. 采樣

這里運用了 reparemerization 的技巧。由於 zN(μ,σ),我們應該從 N(μ,σ) 采樣,但這個采樣操作對 μ 和 σ 是不可導的,導致常規的通過誤差反傳的梯度下降法(GD)不能使用。通過 reparemerization,我們首先從 N(0,1) 上采樣 ϵ,然后,z=σϵ+μ。這樣,zN(μ,σ),而且,從 encoder 輸出到 z,只涉及線性操作,(ϵ 對神經網絡而言只是常數),因此,可以正常使用 GD 進行優化。方法正確性證明見[1] 2.3小節和[2] 第3節 (stochastic backpropagation)。

reparameterization 
圖6. Reparameterization (圖片來源)

preparameterization 的代價是隱變量必須連續變量[7]。

1.3.3 優化目標

encoder 和 decoder 組合在一起,我們能夠對每個 xX,輸出一個相同維度的 x^。我們目標是,令 x^ 與 x 自身盡量的接近。即 x 經過編碼(encode)后,能夠通過解碼(decode)盡可能多的恢復出原來的信息。

注:嚴格而言,按照模型的假設,我們要優化的並不是 x 與 x^ 之間的距離,而是要最大化 x 的似然。不同的損失函數,對應着不是 p(x|z) 的不同概率分布假設。此處為了直觀,姑且這么解釋,詳細討論見下文([1] 附錄C)。

由於 x[0,1],因此,我們用交叉熵(cross entropy)度量 x 與 x^ 差異:

 

xent=i=1n[xilog(x^i)+(1xi)log(1x^i)]

 

xent 越小,x 與 x^ 越接近。

我們也可以用均方誤差來度量: 

mse=i=1n(xix^i)2


mse 越小,兩者越接近。

 

訓練過程中,輸出即是輸入,這便是 VAE 中 AE(autoencoder,自編碼)的含義。

另外,我們需要對 encoder 的輸出 z_mean(μ)及 z_log_var(logσ2)加以約束。這里使用的是 KL 散度(具體公式推導見下文): 

KL=0.5(1+logσ2μ2σ2)=0.5(1+logσ2μ2exp(logσ2))

 

這里的KL, 其實是 KL 散度的負值,見下文。

總的優化目標(最小化)為:

 

loss=xent+KL

 

 

loss=mse+KL

 

綜上所述,有了目標函數,並且從輸入到輸出的所有運算都可導,我們就可以通過 SGD 或其改進方法來訓練這個網絡了。

由於訓練過程只用到 x(同時作為輸入和目標輸出),而與 x 的標簽無關,因此,這是無監督學習。

1.4 小結

總結一下,如圖2,VAE 包括 encoder (模塊 1)和 decoder(模塊 4) 兩個神經網絡。兩者通過模塊 2、3 連接成一個大網絡。得益於 reparemeterization 技巧,我們可以使用常規的 SGD 來訓練網絡。

學習算法的最好方式還是讀代碼,網上有許多基於不同框架的 VAE 參考實現,如 tensorflowtheanokerastorch

2. 直觀解釋

2.1 VAE 有什么用?

2.1.1 數據生成

由於我們指定 p(z) 標准正態分布,再接合已經訓練和的 decoder (p(x|z)),就可以進行采樣,生成類似但不同於訓練集數據的新樣本。 
這里寫圖片描述 
圖7. 生成新的樣本

圖8(交叉熵)和圖9(均方誤差)是基於訓練出來的 decoder,采樣生成的圖像(x^

x_xent
圖8. 交叉熵損失

x_mse
圖9. 均方誤差損失

嚴格來說,生成上圖兩幅圖的代碼並不是采樣,而是 E[x|z] 。伯努力分布和高斯分布的期望,正好是 decocder 的輸出 x^。見下面的討論。

2.1.2 高維數據可視化

encoder 可以將數據 x,映射到更低維的 z 空間,如果是2維或3維,就可以直觀的展示出來(圖10、11)。

z_xent
圖10. 交叉熵損失

z_mse
圖11. 均方誤差損失

2.1.3 缺失數據填補(imputation)

對許多現實問題,樣本點的各維數據存在相關性。因此,在部分維度缺失或不准確的情況,有可能通過相關信息得到填補。圖12、13展示一個簡單的數據填補的實例。其中,第一行為原圖,第二行為中間某幾行像素的缺失圖,第三行為利用 VAE 模型恢復的圖。

i_xent
圖12. 交叉熵損失

i_mse
圖13. 均方誤差損失

2.1.4 半監督學習

相比於高成本的有標注的數據,無標注數據更容易獲取。半監督學習試圖只用一小部分有標注的數據加上大量無標注數據,來學習到一個較好預測模型(分類或回歸)。 
VAE 是無監督的,而且也可以學習到較好的特征表征,因此,可以被用來作無監督學習[3, 12]。

2.2 VAE 原理

由於對概率圖模型和統計學等背景知識不甚了了,初讀[1, 2],對問題陳述、相關工作和動機完全沒有頭緒。因此,先放下公式,回到 comfort zone,類比熟悉的模型,在直覺上理解 VAE 的工作原理。

2.2.1 模型結構

從模型結構(以及名字)上看,VAE 和 自編碼器(audoencoder)非常的像。特別的,VAE 和 CAE(constractive AE)非常相似,兩者都對隱層輸出增加長約束。而 VAE 在隱層的采樣過程,起到和 dropout 類似的正則化偷用。因此,VAE 應該和 CAE 有類似的訓練和工作方式,並且不太易容過擬合。

2.2.2 流形學習

數據雖然高維,但相似數據可能分布在高維空間的某個流形上(例如圖14)。而特征學習就要顯式或隱式地學習到這種流形。

manifold 
圖14. 流形學習 (圖片來源)

正是這種流形分布,我們才能從低的隱變量恢復出高維的觀測變量。如圖8、圖9,相似的隱變量對應的觀測變量確實比較像,並且這樣相似性是平滑的變化。

3. 推導

VAE 提出背景涉及概率領域的最大似然估計(最大后驗概率估計)、期望最大化(EM)算法、變分推理(variational inference,VI)、KL 散度,MCMC 等知識。但 VAE 算法本身的數學推導不復雜,如果熟悉各個內容的話,可以直接跳到 3.6。

3.1 問題陳述

已知變量 x 服從某固定但未知的分布。x 與隱變量(latent variables)的關系可以用圖15 描述。這是一個簡單的概率圖。(注意,x 和 z 都是向量)

DAG 
圖15 兩層的有向概率圖,x為觀測變量,z為隱變量

對於這個概率圖,p(z) (隱變量 z 的先驗)、p(x|z)x 相對 z 的條件概率),及 p(z|x)(隱變量后驗)三者就可行完全描述 x和 z 之間的關系。因為兩者的聯合分布可以表示為:

 

p(z,x)=p(x|z)p(z)

 

x 的邊緣分布可以計算如下: 

p(x)=zp(x,z)dz=zp(x|z)p(z)dz=Ez[p(x|z)]

 

我們只能觀測到 x,而 z 是隱變量,不能被觀測。我們任務便是通過一個觀察集 X,估計概率圖的相關參數。

對於一個機器學習模型,如果它能夠(顯式或隱式的)建模 p(z) 和 p(x|z),我們就稱之為生成模型。這有如下兩層含義: 
1. 兩者決定了聯合分布 p(x,z); 
2. 利用兩者可以對 x 進行采樣(ancestral sampling)。具體方法是,先依概率生成樣本點 zip(z),再依概率采樣 xip(x|zi)

最簡單的生成模型可能是朴素貝葉斯模型

3.2 最大似然估計(Maximum Likelihood Estimation,MLE)

概率分布的參數最經典的方法是最大似然估計

給定一組觀測值

X=(xi), i=1,..,n

。觀測數據的似然為: 

L(pθ(X))=inpθ(xi)

 

一般取似然的對數: 

logL(pθ(X))=inlogpθ(xi)

 

MLE 假設最大化似然的參數 θ 為最優的參數估計。因此,概率的參數估計問題轉化為了最大化 logL(pθ(X)) 的最化問題。

從貝葉斯推理的觀點,θ 本身也是隨機變量,服從某分布 p(θ)

 

p(θ|X)=p(θ)(X|θ)p(X)=p(θ)(X|θ)θp(X,θ)dθp(θ)(X|θ)

 

 

logp(θ|X)=logp(θ)+logL(p(X|θ))

 

這是最大后驗概率估計(MAP)。

3.3 期望最大化算法(Expectation-Maximum,EM)

對於我們問題,利用 MLE 准則,優化目標為: 

logp(X,Z)

 

由於z 不可觀測, 我們只能設法優化: 

logp(X)=logzp(X,z)dz

 

通過 MLE 或 MAP 現在我們已經有了要目標(對數似然),但在我們問題下,似然中存在對隱變量 z的積分。合理假設(指定)p(z) 和 p(x|z)的分布形式,可以用期望最大化算法(EM)解決。

隨機初始化 θold 
E-step:計算 pθold(z|x) 
M-step:計算 θnew,給定: 

θnew=argmaxθQ(θ,θold)

其中, 
Q(θ,θold)=zpθold(z|x)log(pθ(x,z))dz

EM 比較直觀的應用是解決高斯混合模型(Gaussian Mixtrue Model,GMM)的參數估計及K-Means 聚類。更復雜的,語音識別的核心——GMM-HMM 模型的訓練也是利用 EM 算法[5]。

這里我們直接給出 ME 算法而省略了最重要的證明,但 EM 是變分推理的基礎,如果不熟悉建議先參見 [4] Chapter 9 或 [9]。

3. 4 MCMC

EM 算法中涉及到對 p(z|x) (即隱變量的后驗分布)的積分(或求各)。雖然上面舉的例子可以方便的通過 EM 算法求解,但由於概率分布的多樣性及變量的高維等問題,這個積分一般是難以計算的(intractable)。

因此,可以采用數值積分的方式近似求得 M-step 的積分項。 

Q(θ,θold)=zpθold(z|x)log(pθ(x,z))dz1Ni=1Nlogpθ(x,zi)

 

這涉及到按照 p(z|x) 對 z 進行采樣。這需要用到 MCMC 等采樣技術。關於 MCMC,LDA數學八卦 0.4.3 講得非常明白,這里不再贅述。也可以參考 [4] Chapter 11。

3.5 變分推理(Variational Inference,VI)

由於 MCMC 算法的復雜性(對每個數據點都要進行大量采),在大數據下情況,可能很難得到應用。因此,對於

p(z|x)

的積分,還需要其他的近似解決方案。

 

變分推理的思想是,尋找一個容易處理的分布 q(z),使得 q(z) 與目標分布p(z|x) 盡量接近。然后,用q(z) 代替 p(z|x)

分布之間的度量采用 Kullback–Leibler divergence(KL 散度),其定義如下:

KL(q||p)=q(t)logq(t)p(t)dt=Eq(logqlogp)=Eq(logq)Eq[logp]

在不致引起歧義的情況下,我們省略 E 的下標。這里不加證明的指出 KL 的一些重要性質:KL(q||p)0 且 KL(q||p)=0q=p [6]

注:KL散度不是距離度量,不滿足對稱性和三角不等式

因此,我們尋找 q(z) 的問題,轉化為一個優化問題: 

q(z)=argmaxq(z)QKL(q(z)||p(z|x))

 

KL(q(z)||p(z|x)) 是關於 q(z) 函數,而 q(z)Q 是一個函數,因此,這是一個泛函(函數的函數)。而變分(variation)求極值之於泛函,正如微分求極值之於函數。 
如果對於變分的說法一時不好理解,可以簡單地將變分視為高斯分布中的高斯、傅里葉變換中的傅里葉一樣的專有名詞,不要嘗試從字面去理解。 
另外不要把變分(variation)與 variable(變量), variance(方差)等混淆,它們之間沒有關系。

ELBO(Evidence Lower Bound Objective)

根據 KL 的定義及 p(z|x)=p(z,x)p(x) 

KL(q(z)||p(z|x))=E[logq(z)]E[logp(z,x)]+logp(x)

 

令 

ELBO(q)=E[logp(z,x)]E[logq(z)]


根據 KL 的非負性質,我們有 

logp(x)=KL(q(x)||p(z|x))+ELBO(q)ELBO(q)

 

ELBO 是 p(x) 對數似似然(即證據,evidence)的一個下限(lower bound)。

對於給定的數據集,p(x) 為常數,由 

KL(q(x)||p(z|x))=logp(x)ELBO(q)


最小化 KL 等價於最大化 ELBO 。

 

關於變分推理這里就簡單介紹這么多。有興趣的話可以參考 [6]、[4] Chapter 10 以及最新的 tutorial [10]。

3.6 VAE

這里主要是按照 [1] 的思路來討論 VAE。

觀測數據 x(i) 的對數似然可以寫作: 

logpθ(x(i)=KL(qΦ(z|x(i))||pθ(z|x(i)))+L(θ,Φ;x(i)))

 

這里我們將 ELBO 記作 L,以強調需要優化的參數。 
我們可以通過優化 L,來間接的優化似然。

VI 中我們通過優化 L 來優化 KL。

根據概率的乘法公式,經過簡單的變換,L 可以寫作 

L(θ,Φ;x(i)))=KL(qΦ(z|x(i))||pθ(z))+EqΦ(z|x)[logpθ(x(i)|z)]

 

因此,我們優化的目標可以分解成等號右邊的兩項。

3.6.1 第一項

我們先考察第一項,這是一個 KL 散度。q 是我們要學習的分布,p 是隱變量的先驗分布。通過合理的選擇分布形式,這一項可以解析的求出。

如果,q 取各維獨立的高斯分布(即第1部分的 decoder),同時令 p 是標准正態分布,那么,可以計算出,兩者之間的 KL 散度為: 

KL(qΦ(z|x(i))||pθ(z))=0.5(1+logσ2iμ2iσ2i)=0.5(1+logσ2iμ2iexp(logσ2i))

 

這就是本文第1部分目標函數的 KL 項了。

具體證明見 [1] 附錄B。

3.6.2 第二項

然后,我們考察等式右邊第二項。EqΦ(z|x)[logpθ(x(i)|z)] 是關於 x(i) 的后驗概率的對數似然。

由於 VAE 並不對 q(z|x) (decoder) 做太強的假設(我們的例子中,是一個神經網絡),因引,這一項不能解析的求出。所以我們考慮采樣的方式。 

EqΦ(z|x)[logpθ(x(i)|z)]1Lj=1Llogpθ(x(i)|z(j))


這里 z(j) 不是通過從 decoder 建模的高斯分布直接采樣,而是使用了第1部分介紹的 reparameterization 方法,其正確性證明見[1]的2.3小節。

 

如果每次只采一個樣本點,則 

EqΦ(z|x)[logpθ(x(i)|z)]logpθ(x(i)|z~)

 

其中,z~ 為采樣點。很幸運,這個式正是神經網絡常用的損失函數。

3.6.3 損失函數

通過上面討論,VAE 的優化目標都成為了我們熟悉並容易處理的形式。下面,我們針對 pθ(x(i)|z~)(encoder)的具體建模分布,推導下神經網絡訓練中實際的損失函數。

第1部分介紹了 交叉熵和均方誤差兩種損失函數。下面簡單介紹下,兩種損失對應的不同概率分布假設。以下分布均假設 x 的各維獨立。

交叉熵

如果假設 p(xi|z),(i=1,..,n) 服從伯努力分布,即: 

p(x=1|z)=αz,p(x=0)=1αz


對於某個觀測值,其似然為: 

L=αxz(1αz)1x

 

decoder 輸出為伯努力分布的參數,即 αz=decoder(z)=x^。則對數似然為: 

logL=xlog(x^)+(1x)log(1x^)


logL 這就是我們使用的交叉熵。

 

均方誤差

如果假設

p(xi|z),(i=1,..,n)

服務高斯分布,即 

p(x|z)=12π−−√σe(xμ)22σ2

 

對數似然為:

 

logL=0.5log(2π)0.5logσ(xμ)22σ2

 

decoder 為高斯分布的期望,這里不關心方差,即σ為未知常數。我們是優化目標為(去掉與優化無關的常數項): 

max(xμ)22σ2=min(xμ)2


這就是我們要優化的均方誤差。

 

對不同損失函數與概率分布的聯系,詳細討論見 [4] Chapter 5。

4. 結語

對這個領域的接觸不多,認識淺顯,文獻也讀的少,更多的是一些疑問:

  • VAE 是非常漂亮的工作,是理論指導模型結構設計的范例。

  • [1] [2] 獨立提出 VAE。雖然最后提出的算法大致相同,但出發點和推導思路還是有明顯不同,應該放在一起相互參照。

  • VAE 作為一種特征學習方法,與同樣是非監督的 AE、 RBM 等方法相比,優劣勢分是什么?

  • [2] 討論了與 denoising AE 的關系,但 VAE 在形式上與 constractive auto-encoder 更相似,不知道兩者的關系如何理解。

  • 有些工作利用 VAE 作半監督學習,粗略看了些,並沒有展現出相比於其他預訓練方法的優勢[3, 12]。

  • 結合上面幾點,雖然 VAE 是一個很好的工具,新的“論文增長點”,但僅就深度學習而言,感覺僅僅只是另一種新的工具。


Refences

  1. Kingma et al. Auto-Encoding Variational Bayes.
  2. Rezende et al. Stochastic Backpropagation and Approximate Inference in Deep Generative Models.
  3. Kingma and Rezende et al. Semi-supervised Learning with Deep Generative Models.
  4. Bishop. Pattern Recognition and Machine Learning.
  5. Young et al. HTK handbook.
  6. Blei et al. Variational Inference: A Review for Statisticians.
  7. Doersch. Tutorial on Variational Autoencoders.
  8. Kevin Frans. Variational Autoencoders Explained.
  9. Sridharan. Gaussian mixture models and the EM algorithm.
  10. Blei et al. Variational Inference: Foundations and Modern Methods.
  11. Durr. Introduction to variational autoencoders .
  12. Xu et al. Variational Autoencoders for Semi-supervised Text Classification.

Further Reading


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM