貝葉斯推斷
由上一篇我們已經了解到,對於未知的分布或者難以計算的問題,我們可以通過變分推斷將其轉換為簡單的可計算的問題來求解。現在我們貝葉斯統計的角度,來看一個難以准確計算的案例。
推斷問題可以理解為計算條件概率$p(y|x)$。利用貝葉斯定理,可以將計算條件概率(或者說后驗概率,posterior)轉換為計算聯合概率(joint distribution)和先驗概率(prior):
\begin{equation} p(y|x) = p(y,x) / p(x) \label{1.3} \end{equation}
其中$p(y,x)$為聯合概率密度,$p(y,x)=p(x|y)p(y)=p(y|x)p(x)$,$p(x)$為邊緣概率密度(marginal density),又稱為evidence,或者x的先驗。邊緣概率由對聯合概率的其他值做積分得到:
\begin{equation} p(x) = \int{p(y,x)} dy \nonumber \end{equation}
式$(\ref{1.3})$的方法就是貝葉斯推斷。朴素貝葉斯模型(naïve Bayesian model)正是基於貝葉斯推斷的模型。
朴素貝葉斯是文本分類的經典模型。在文本分類任務中,$y$為文本的類別,$x$為文本,文本有特征$x_1⋯x_n$,這些特征可能是某個詞或詞組,那么:
\begin{align} p(y,x) &= p(y)p(x | y) \nonumber \\ &= p(y)p(x_1,⋯,x_2 |y) \nonumber \end{align}
朴素貝葉斯中假設特征$x_1⋯x_n$是相互獨立的,所以:
\begin{equation} p(y,x) = p(y) \prod_i {p(x_i |y)} \nonumber \end{equation}
在進行推斷時,$p(x)$會被忽略,因為文本的特征的分布對文本的類別沒有影響,對所有文本,它都是固定的$p(x)$,也就是說$(\ref{1.3})$中等號右邊的分母對所有文本類型來說都相同,判斷文本類型只需要分子部分。所以要計算p(y│x),只需要計算$p(y)$和$p(x_i |y)$。$p(y)$和$p(x_i |y)$的計算如下:
\begin{align} &p(y) = \frac{|D_y |}{|D|} \label{1.4} \\ &p(x_i | y) = \frac{|D_(y,x_i )|}{|D_y|} \nonumber \end{align}
其中$D$為文本總數,$D_y$為某個類別的文本的數量,而$D_(y,x_i )$為屬於某個類別且包含特征$x_i$的文本的數量。
$p(x)$的問題
上面用朴素貝葉斯進行文本分類的例子中,我們忽略了$p(x)$,從而使得貝葉斯推斷可行。對於一些模型,$p(x)$是不能被忽略,而且計算$p(x)$的時間復雜度非常高,在這種情況下,貝葉斯推斷$p(y|x)$不能直接求解。例如,我們假設數據$x$服從某種混合高斯分布(mixture of Gaussian):
\begin{align} &x_i |c_i,\mu \sim N(c_i^T \mu,1) \nonumber \\ &c_i \sim Categorical(\frac{1}{K},⋯,\frac{1}{K}) \nonumber \\ &\mu_k \sim N(0, \sigma^2) \nonumber \end{align}
其中$x_i$表示$n$個數據中的某個數據,它服從高斯分布$N(c_i^T \mu,1)$,$c_i$是一個向量,向量中元素的取值為0或1,點乘$c_i^T \mu$表示取均值向量$\mu$中的某個均值,例如$mu_k$,而$mu_k$服從高斯分布$N(0,\sigma^2)$,$\sigma^2$是一個超參數(hyperparameter),由人為設定。我們可以想象對所有數據$x_i$進行聚類(clustering),並假設共有$K$類(cluster),那么某個數據$x_i$所屬的類$c_i$為$0~K$中的某一類$k$,且屬於第$k$類的概率為$\frac{1}{K}$,而第$k$類的數據的均值是$\mu_k$,方差設為1(可以假設$x_i$都是經過歸一化處理的數據,值落在[0,1]區間)。對於這樣的數據,我們可以用高斯混合模型(Gaussian Mixture Model,GMM)對其進行聚類——一種迭代的方法,和EM(Expectation Maximization,期望最大)方法有關聯,以后會介紹EM——數據$x$是觀測量,而它所屬的類以及類的均值是隱變量$y={c_i,\mu}$,所以聚類也是在求$p(y|x)$。
上面描述的數據$x$,它與$c_i$和$\mu$的聯合概率密度函數為:
\begin{equation} p(x,c,\mu) = p(\mu) \prod_{i=1}^n{p(c_i )p(x_i|c_i,μ)} \label{1.5} \end{equation}
其中$c={c_1,⋯,c_n}$。因為數據$x$是獨立同分布的,$c_i$是$x_i$的局部變量,也就是說一個$x_i$對應一個$c_i$,與其他$c_i$沒有關系,而$\mu$是一個全局變量,與$i$無關,所以不能在乘積中被拆分為$\mu_i$。計算聯合密度$p(x,c,\mu)$時,需要計算所有數據對應的$p(c_i)p(x_i |c_i,\mu)$的乘積,因為是$\mu$是一個全局變量,不能拆分,所以聯合密度的計算為$(\ref{1.5})$的形式。
因為$c_i$的先驗是離散的(discrete),像1、2、3這樣的“孤立”值,所以計算$x$的邊緣概率密度$p(x)$時,要通過求和覆蓋$c_i$可能的取值,而$\mu$是連續的(continuous),要對其積分,所以$x$的邊緣概率密度為:
\begin{equation} p(x) = \int{p(\mu) \prod_{i=1}^n{\sum_{c_i}{p(c_i )p(x_i│c_i,\mu)}}} dμ \label{1.6} \end{equation}
假設數據$x$有$K$個類,$(\ref{1.6})$的時間復雜度(time complexity)就是$O(K^n)$。這是一個指數級的復雜度。隨着數據量$n$的增加,計算將無法完成。在這種情況下,我們是無法用貝葉斯推斷$(\ref{1.3})$去計算后驗$p(y|x)$的。
ELBO
上面我們已經從貝葉斯統計的角度了解到,條件概率$p(y|x)$難以計算的原因在於$x$的先驗$p(x)$,下面我們還會看到,在以KL散度作為變分分布訓練的目標函數時,$p(x)$同樣會給計算帶來麻煩。
當我們設計一個函數$q(y)$或者$q(y|x)$作為變分函數,然后使它接近真實的后驗時,我們可以用KL散度來衡量它和真實分布$p(y|x)$的差異。,假設部分分布是$q(y)$,那么變分推斷的目標是求最優的變分分布函數$q^* (y)$:
\begin{equation} q^* (y) = \argmin_{q(y)\in L} {KL(q(y) \parallel p(y|x))} \label{1.7} \end{equation}
其中$\argmin_{q(y) \in L}$表示從$L$中找出使得KL散度最小的$q(y)$,而$L$是所有可能的$q(y)$的族(family)。族的復雜度決定了優化問題(optimization)的復雜度。
其實,無論選擇什么樣的族,$(\ref{1.7})$仍然是難以計算的,因為KL散度的計算如下:
\begin{equation} KL(q(y) \parallel p(y|x)) = E_q [\log{q(y)} ] - E_q [\log{p(y|x)} ] \label{1.8} \end{equation}
而就像前面分析的,$p(y|x)$是計算的困難所在。為了使優化問題能夠計算,我們需要將$(\ref{1.8})$替換掉。
首先,我們對$(\ref{1.8})$做一些轉換:
\begin{align} KL(q(y) \parallel p(y|x)) &= E_q [\log{q(y)} ] - E_q [\log{\frac{p(y,x)}{p(x)}} ] \nonumber \\ &= E_q [\log{q(y)}] - E_q [\log{p(y,x)} ]+ E_q [\log{p(x)} ] \label{1.9} \end{align}
調整$(\ref{1.9})$中各項的位置,得到:
\begin{align} &E_q [\log{p(x)} ]= KL(q(y) \parallel p(y|x))- E_q [\log{q(y)} ] + E_q [\log{p(y,x)}] \nonumber \\ &\log{p(x)} = KL(q(y) \parallel p(y|x)) + ELBO(q) \label{1.10} \end{align}
其中$E_q [\log{p(x)} ]=\log{p(x)}$,因為從分布$q(y)$中對$y$采樣,與$x$無關,$\log{p(x)} E_q [1]=\log{p(x)}$。$(\ref{1.10})$中$ELBO$(Evidence Lower BOund)是log evidence的下界,也就是 $\log{p(x)}$的下界,因為KL散度的值只能大於等於0。當$q(y)$等於$p(y|x)$時,KL散度為0,$\log{p(x)}=ELBO(q)$,否則$ELBO(q)$小於$(\ref{1.10})$等號左邊。
實際上,我們只需要優化$ELBO$,因為:
\begin{equation} KL(q(y) \parallel p(z│x)) = \log{p(x)}- ELBO(q) \nonumber \end{equation}
而$\log{p(x)}$對優化目標$q$來說是一個常數,不影響$q$的優化。對$ELBO$還可以以下面的方式表示:
\begin{align} ELBO(q) &= E_q [\log{p(y,x)} ]- E_q [\log{q(y)} ] \nonumber \\ &= E_q [\log{p(x|y)} ] + E_q [\log{p(y)} ] - E_q [\log{q(y)} ] \nonumber \\ &= E_q [\log{p(x|y)} ] - KL(q(y) \parallel p(y)) \label{1.11} \end{align}
式$(\ref{1.11})$等號右邊第一項是expected log-likelihood,而第二項是隱變量$y$的變分分布與真實分布的KL散度,其中變分分布$q(y)$是優化的目標。
將原先的目標函數KL散度替換為$ELBO$之后,我們就可以通過優化,找到$q$來進行推斷了(提醒一下,變分推斷並不局限於貝葉斯推斷)。
平均場
上面我們已經得到了變分推斷模型的目標函數(objective function),也就是$ELBO$,現在我們來看一下如何選擇變分分布$q(y)$的族(family)。
在上一篇【變分近似】中,我們了解到論文$[2]$的作者邁克爾·喬丹等人采用Frenchel共軛來確定$q(y)$。 Frenchel共軛要求預先知道$p(x|y)$的表達式,才能將p(x|y)轉換為q(y)。這里我們來看。另一種構造變分推斷的方法——平均場(mean-field)方法。
在平均場變分族中,隱變量是相互獨立的,所以平均場的族可以統一表示為:
\begin{equation} q(y) = \prod{q_i (y_i)} \nonumber \end{equation}
例如,服從混合高斯分布的數據$x$的隱變量$\mu$和$c$滿足:
\begin{equation} q(\mu,c) = \prod_k{q(\mu_k|m_k,s_k^2 )} \prod_i{q(c_i|φ_i )} \label{1.12} \end{equation}
其中$\mu_k$是第$k$類的數據的均值,它服從高斯分布$N(m_k,s_k^2)$,例如$N(0,\sigma^2)$,而$c_i$是第$i$個數據對應的類,它服從多項分布,例如$Categorical(\frac{1}{K},⋯,\frac{1}{K})$,$\varphi_i=\frac{1}{K}$。此時,隱變量被替換為$m_k$、$s_k^2$和$\varphi_i$等變分參數。變分推斷模型的學習過程就是優化這些參數,從而最大化$ELBO$。
mean-field是非常簡單的構造變分推斷的方法,它假設所有變量是相互獨立。這種簡單的方法有一個缺點,就是它無法反映隱變量之間的關系。另外,論文$[1]$中指出,mean-field方法會低估真實分布的方差(variance),因為KL散度是不對稱的,$KL(q \parallel p)$不等於$KL(p \parallel q)$,而$KL(q \parallel p)$會對$q$不落在分布$p$的懲罰很大,而忽略$q$沒有覆蓋$p$的情況,因為:
當$q \to 0$而$p \to 1$,此時$q \log{\frac{q}{p}} \to 0$,KL的影響趨於零;
當$q \to 1$而$p \to 0$,此時$q \log{\frac{q}{p}} \to + \infty$,KL的值非常大。
這在【GAN基礎】中有分析,個人認為它是KL散度的問題,而不是mean-field本身的缺陷。
CAVI
現在我們已經知道了變分推斷的目標函數$ELBO$,包括其中的變分分布$q(y)$,現在我們來理解變分參數的求解過程。論文$[1]$中采用的是CAVI(Coordinate Ascent Variational Inference)算法。
CAVI是一種迭代方法。在CAVI中,每次只更新一個隱變量,其他隱變量作為輸入求這個變量,也就是$p(y_j |y_{-j},x)$,其中$y_j$是要計算的隱變量,$y_{-j}$是其他隱變量。例如在$(\ref{1.12})$中,$y_j$是$m_k$,而$y_{-j}$代表$s_k^2$、$\varphi_i$,以及$m_{-k}$。對式$(\ref{1.11})$的第一行進行調整,使它只優化$q_j$,也就是固定其他隱變量,計算$y_j$,並使ELBO最大化:
\begin{equation} ELBO(q_j )= E_j [E_{-j} [\log{p(y_j,y_{-j},x)}]]- E_j [\log{q_j (y_j )} ] + const. \nonumber \end{equation}
其中$p(y_j,y_{-j},x)$為各個變量的聯合分布密度,$const.$代表一個常數。
經過迭代計算,最終變分推斷會收斂於一個局部最優解(local optimum),此時,對於所有隱變量,$ELBO$都是最優。
CAVI雖然能夠求解變分推斷,但是它的不適用於大規模的數據集,因為它要求每一次迭代時都計算一遍所有的數據,在下一篇,我們將會了解一種隨機優化的方法,它可以除了大規模的數據集。
[1] Blei, D. M., Kucukelbir, A., McAuliffe, J. D. (2018). “variational inference a review for statisticians”.
[2] Jordan, M. I., Ghahramani, Z., Jaakkola, T., and Saul, L. (1999). “Introduction to variational methods for graphical models”.