變分推斷(一)


引言
GAN專題介紹了GAN的原理以及一些變種,這次打算介紹另一個重要的生成模型——變分自編碼器(Variational AutoEncoder,VAE)。但在介紹編碼器之前,這里會先花一點時間介紹變分推斷(Variational Inference,VI),而這一小系列最后還會介紹貝葉斯神經網絡——其中用到了變分推斷。所以,這一系列其實是一個變分推斷系列。

 

變分推斷基礎

所謂推斷,就是根據一定的知識或信息進行推演,然后做出某種判斷。在機器學習中,無論是分類、回歸還是聚類,都是推斷任務。例如分類任務接收數據,然后判斷它們的類別標簽。類別標簽或者回歸任務中的回歸值,又可以看成是數據的某種隱藏表征(latent representation)。這種表征可以用來描述數據,例如說“貓”的標簽說明圖像中有貓。實際上,神經網絡中的隱藏層也可以看成是某種表征,這些表征可以是像標簽一樣是解耦的(disentangled)——一個個“獨立”的——或者是非解耦的。

從統計概率來說,推斷就是基於“條件”$x$,通過條件概率$p(y|x)$去判斷某件事$y$發生的概率。但是這種從$x$到$y$的關系很多時候是未知的,或者它的解析解(analytical solution)是難以計算的——計算復雜度過高。這時候,我們就要通過一個近似解(approximate)$q$來逼近這個條件概率$p(y|x)$。

求數據分布的近似有兩個重要方法[1],其中一個是馬爾科夫鏈蒙特卡洛(Markov Chain Monto Carlo,MCMC)——MCMC是一個更傳統的方法,在介紹cGAN時有提到,后面會設一個小專題詳細介紹它(先挖個坑)——另一個是變分推斷。變分推斷和MCMC的目標都是求真實的概率密度函數(probability density)$p$的近似$q$,但MCMC是采樣(sampling)的方法,而變分推斷是通過優化(optimization)來求近似解。變分推斷的核心思想是從變分分布的族(family)中找出其中一部分“成員”$q$,使得$q$和$p$的KL散度(Kullback-Leibler divergence)$KL(q \parallel p)$盡可能小——其實也可以取KL散度外的其他的衡量指標。KL散度的數學表示為:

\begin{equation} KL(q \parallel p) = E_q [\log{q} - \log{p} ] \label{1.1} \end{equation}

其中$E_q$表示從$q$采樣,計算$\log{q} - \log{p}$的期望(expectation)$E$。當變分分布(variational distribution)$q$等於真實分布$p$時,KL散度等於0。如果用參數$\theta$來表示分布$q_{\theta}$——例如高斯分布$N(\mu,\sigma^2)$用參數均值$\mu$和方差$\sigma^2$來表示——那么變分推斷就是求參數$\theta$的值或范圍,使$q_{\theta}$近似$p$。這也是變分推斷被“變分”推斷的原因。

 

變分法、變分近似、變分推斷
變分推斷的變分來自變分法(calculus of variations)。變分又稱為Frechet微分,可以理解為無限維空間上的微分。對微分(differential)來說,當我們將$x$移動$dx$時,$f(x)$會變為$f(x+dx)$。而在變分中,自變量不是點$x$,而是函數$f(x)$。當函數$f(x)$改變時,它的泛函(functional,函數的函數)$F$的輸出$F(f(x))$也會發生改變。對一個函數$f$求極值(最大值maximum或最小值minimum),是在所有$x$中找到某個點,使$f(x)$取到極值;而在求一個泛函的極值時,我們是要找出$f$使泛函取到極值,例如在式$(\ref{1.1})$中,我們要找到$q$使得KL散度盡可能小。

 

變分近似
上面提到,因為概率密度$p(y|x)$未知或者難以計算,所以要采用近似的方法求解。論文$[2]$中,邁克爾·喬丹等人將變分推斷用於概率圖模型(probabilistic graphical model,或graphical model,圖模型)的計算——以后會專門設一個專題系統地介紹概率圖模型(第二個坑)。具體來說,邁克爾·喬丹等人通過變分近似(variational approximation),將原問題轉換為簡單的問題。將原問題轉換為新問題的過程時,新的參數會被引入——變分參數(variational parameters)。
論文$[2]$中提到的近似方法是利用了凸共軛函數(convex conjugate function,或者稱Frenchel conjugate)——在介紹f-GAN時提到過。這種變分近似的關鍵,在於問題的凸性(convexity)或者說凹性(concavity)——例如(圖1)所示,相對於黑色直線,紅色和藍色兩條曲線都只有一個極值——這一性質使得我們可以用原問題的上界(upper bound)或者下界(lower bound)來近似原問題。例如,我們可以用下面等號右邊來近似對數函數(logarithm function)$\log{x}$:
\begin{equation} \log{x} = \min_{\lambda} { {\lambda x- \log{\lambda} -1} } \nonumber \end{equation}
其中$ \lambda x - \log{\lambda} - 1$是以$x$為自變量,以$\lambda$斜率的線性函數。如$(圖2)$中的虛線所示,$\lambda$取不同值時,線性函數對應不同的虛線,但對確定的$x$來說,虛線上的$y$都要大於或等於對數曲線(實線)的值,也就是:
\begin{equation} \log{x} \leq \lambda x - \log{\lambda} - 1 \label{1.2} \end{equation}

變分近似就是要找出$\lambda$,使得對於某$x$$(\ref{1.2})$右邊的變分函數和左邊的原函數盡可能接近,而且因為右邊的線性函數是一個更容易求解的函數,計算它所需的計算量顯然要小於計算原函數。


(圖1,來自https://www.wallstreetmojo.com)

(圖2,來自《an introduction to variational methods for graphical models》)

 

對數函數$\log{x}$,它的變分變換是$\lambda x - \log{\lambda} - 1$,那其他函數我們如何確定它們的變分變換呢?采用Frenchel共軛方法,凹函數$f$表示為:
\begin{equation} f(x) = \min_{\lambda}{{\lambda^T x - f^* (\lambda)}} \nonumber \end{equation}
$\lambda^T$是$\lambda$的向量化表示方式。$f$的共軛函數:
\begin{equation} f^* (x) = \min_x⁡{{\lambda^T x - f(x)}} \nonumber \end{equation}
對於凸函數則有下面的共軛關系:
\begin{align} f(x) = \max_{\lambda}{{\lambda^T x - f^* (\lambda)}} \nonumber \\ f^* (x) = \max_x{{\lambda^T x - f(x)}} \nonumber \end{align}
可以發現,$\lambda^T x - f^* (\lambda)$是一個對於$x$的線性函數,而共軛函數$f^* (\lambda)$是線性函數的截距(intercept term)。但需要注意的是,凸共軛並不局限於線性邊界,我們還可以用其他類型的邊界,例如二次邊界:
\begin{align} &f(x) = \min_{\lambda}{{\lambda^T x^2 - \bar{f}^* (\lambda)}} \nonumber \\ &\bar{f}(x)=f(x^2) \nonumber \end{align}

總之,變分推斷就是用一個簡單的函數去近似原問題,從而是問題能夠求解。 

 

小結

對變分推斷有一個大致的概念后,現在我們梳理一下這一節的知識點,然后在下一篇介紹變分推斷的目標函數以及求解過程。

變分推斷要做的是根據$x$對$y$進行推斷,但是因為准確推斷難以計算,所以我們采用近似的辦法來求解。變分近似就是一種近似方法。通過它,我們能將原問題轉換為較為簡單的可計算的問題。在這個轉換過程中,這個方法會引入變分參數。不同的變分參數決定了函數族中函數,就像索引一樣。通過變分法,我們能找出使得變分轉換后的問題接近於原問題的參數,或者說函數。

 

 

未完待續~

 

 [1] Blei, D. M., Kucukelbir, A., McAuliffe, J. D. (2018). “variational inference a review for statisticians”.

 [2] Jordan, M. I., Ghahramani, Z., Jaakkola, T., and Saul, L. (1999). “Introduction to variational methods for graphical models”.


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM