SVI
變分推斷的前兩篇介紹了變分推斷的構造方法、目標函數以及優化算法CAVI,同時上一篇末尾提到,CAVI並不適用於大規模的數據的情況,而這一篇將要介紹一種隨機優化(stochastic optimization)的方法。這種優化方法與隨機梯度下降(Stochastic Gradient Descent,SGD)方法有相近,它能夠處理大規模數據。通過這種方法進行優化的變分推斷,我們稱為隨機變分推斷(Stochastic Variational Inference,SVI)。(需要注意的是,這里介紹的是一種通用優化算法,並不局限於優化變分推斷)
隨機梯度下降
梯度下降是廣泛用於機器學習,尤其是深度學習模型訓練的優化算法之一——關於優化算法,以后會開一個專題來介紹。在處理大規模數據時,我們可以采用隨機梯度下降法,分批次地處理小規模數據。梯度下降法采用下面的方式優化模型的參數:
\begin{align} &\theta^{t+1} = \theta^t - \eta \frac{\partial f}{\partial \theta} \label{1.13} \\ &\frac{\partial f}{\partial \theta} = \begin{bmatrix} \frac{\partial f}{\partial \theta_1} & \frac{\partial f}{\partial \theta_2} & ⋯ & \frac{\partial f}{\partial \theta_k} \end{bmatrix}^T \nonumber \\ & \theta^t = \begin{bmatrix} \theta_1^t & \theta_2^t & ⋯ & \theta_k^t \end{bmatrix}^T \nonumber \end{align}
其中$\theta^t$是當前參數的值(一系列參數$\theta_1^t,\theta_2^t,⋯,\theta_k^t$組成的向量),$\theta^{t+1}$是第$t+1$次優化后的參數的值,$\eta$是超參數(hyper parameter)學習率(learning rate),由人設定,而$\frac{\partial f}{\partial \theta}$是函數$f$對參數$\theta$的梯度(或者說一階導數)。當移動$\Delta \theta$能夠使函數的值變小,也就是梯度為負值,那么參數$ \theta^t$就會向$\Delta \theta$方向移動,值變為$\theta^{t+1}$,這樣函數的值就會越來越小,最終得到局部最小值(如果函數是非凸函數,有多個極值,否則得到全局極值),如$(圖4、5)$所示。
(圖4,來自zhuanlan.zhihu.com/p/36564434)
(圖5,來自https://blog.csdn.net/zhulf0804/article/details/52250220)
黎曼測度
上面的方法是標准的梯度下降法,可以看到梯度的計算采用的是歐式距離(Euclidean distance,歐幾里得距離)。當參數$\theta_1^t$移動$\Delta \theta_1$,參數$\theta_2^t$移動$\Delta \theta_2$,其他參數以相同方式移動,那么函數$f$移動的歐式距離是:
\begin{align} d(\theta,\theta+\Delta \theta) = \sqrt{\sum{\Delta \theta_i}} = \sqrt{\Delta \theta^T \Delta \thetaθ} = \parallel \Delta \theta \parallel_2 \label{1.14} \end{align}
$(\ref{1.14})$式中$\Delta \theta^T \Delta \theta$表示向量點積(inner product),$\parallel \Delta \theta \parallel_2$是歐式范數(Euclidean norm),又稱L2范數。如$(圖5)$所示,從0點移動到1點,在x軸和y軸分別移動$\theta_1$和$\theta_2$。
但是,歐式距離並不適用於所有情況,因為參數可能並不在歐式空間(Euclidean space)中。例如,從$\theta_1^t$移動$\Delta \theta_1$到$\theta_1^{t+1}$,和從$\theta_1^{t+1}$移動$\Delta \theta_1$到$\theta_1^{t+2}$,從歐式距離來看都是移動了$\Delta \theta_1$,但在非歐空間中,兩個$\Delta \theta_1$可能是不同的。看$(圖6)$的例子,上下兩個圖都是從紅色分布移動到綠色分布,從均值(mean)來看,都是從-1變為1,移動了2,但是$(圖6)$的上面的圖中,在某種意義上,分布發生的變化要比下圖中的變化大。對這樣的情況,我們引入黎曼幾何(Riemannian geometry)。
在黎曼幾何中,兩點的距離不是通過歐式范數$(\ref{1.14})$來計算的,而是通過:
\begin{align} d(\theta, \theta+\Delta \theta) &= \sqrt{\sum_i{\sum_j{\Delta \theta_i \Delta \theta_j g_{i,j} (w)}}} \nonumber \\ &= \sqrt{\Delta \theta^T G(\theta) \Delta \theta} \label{1.15} \end{align}
其中$G(\theta)$是黎曼測度張量(Riemannian metric tensor),它由$\theta$決定——這里不做詳細的推導了,可以參考論文$[3]$的section 3的Example,論文中的式(15)。另外,當$G(\theta)$為單位矩陣時——從左上角到右下角的對角線上的值都為1,其他位置的值都為0的矩陣——$(\ref{1.15})$等於$(\ref{1.14})$,此時計算的是歐式距離。
(圖6,來自http://kvfrans.com/what-is-the-natural-gradient-and-where-does-it-appear-in-trust-region-policy-optimization/)
直覺上,$G(\theta)$描述了幾何空間對兩點間的路徑的影響。例如,在黎曼幾何的一個經典應用案例,廣義相對論中,光線在引力場中發生了彎曲,而不是直線行走。
Fisher信息矩陣
從$(圖6)$的例子可以看到,采用歐式距離來衡量概率分布的變化量不是一個好主意。分布的差異我們更多的是通過KL散度等指標來衡量。當$\theta$是分布的參數,而且我們用KL散度來衡量兩個分布的差異時,上面介紹的黎曼測度張量$G(\theta)$就是Fisher信息矩陣(Fisher Information Matrix,FIM):
\begin{equation} F=E_{p(\theta)} [\triangledown \log{p(\theta)} \triangledown \log{p(\theta)}^T ] \label{1.16} \end{equation}
其中$\triangledown$表示一階導。
為了證明在KL散度作為距離指標時,Fisher信息矩陣是黎曼測度張量$G(\theta)$,我們先來看一下KL散度的泰勒展式。泰勒展式的通用形式如下:
\begin{equation} f(x_0+ \Delta x)=f(x_0 )+ \Delta x f' (x_0 )+ \Delta x^2 f'' (x_0 )+⋯ \nonumber \end{equation}
其中$f'$為函數$f$的一階導,$f''$為函數的二階導,$\Delta x^2$為移動距離$\Delta x$的平方。等號右邊如果去掉省略號部分,則表示是二階泰勒展式,如果只保留前兩項,則是一階泰勒展式。KL的泰勒展式為:
\begin{align} KL(q(\theta+ \Delta \theta) \parallel p(\bar{\theta})) &=KL(q(\theta) \parallel p(\bar{\theta})) \nonumber \\ &+ (\triangledown_{\theta} KL(q(\theta) \parallel p(\bar{\theta})))^T \Delta \theta \nonumber \\ &+ \frac{1}{2} \Delta \theta^T \triangledown_{theta}^2 KL(q(\theta) \parallel q(\bar{\theta}))\Delta \theta + \dotsb \label{1.17} \end{align}
其中$\bar{\theta}$是固定的值,$\theta$才是自變量,$\triangledown_{\theta}$表示一階導,$\triangledown_{\theta}^2$表示二階導。簡化$(\ref{1.17})$,得到:
\begin{align} KL(q(\theta+\Delta \theta) \parallel q(\bar{\theta})) &\approx KL(q(\theta) \parallel p(\bar{\theta})) \nonumber \\ &+ \triangledown_{\theta} E_{q(\theta)} [\log{q(\theta)} ]^T \Delta \theta \nonumber \\ &- \frac{1}{2} \Delta \theta^T F \Delta \theta \label{1.18} \end{align}
其中$(\ref{1.17})$等號右邊第二項到$(\ref{1.18})$等號右邊第二項的推導如下:
\begin{align} \triangledown_{\theta} KL(q(\theta) \parallel q(\bar{\theta})) &= \triangledown_{\theta} E_{q(\theta)} [\log{q(\theta)}] - \triangledown_{\theta} E_{q(\theta)} [\log{p(\bar{\theta})} ] \nonumber \\ &= \triangledown_{\theta} E_{q(\theta)} [\log{q(\theta)}] = 0 \label{1.19} \end{align}
因為第一行等號右邊第二項中的$\log{q(\bar{\theta})}$對$\theta$是常數,求導結果為0。最終$(\ref{1.19})$為0,因為:
\begin{align} \triangledown_{\theta} E_{q(θ)} [\log{q(\theta)} ] &= E_{q(θ)} [∇_θ logq(θ) ] \nonumber \\ &= \int{q(\theta) \triangledown_{\theta} \log{q(\theta)}} d\theta \nonumber \\ &=\int{q(\theta) \frac{\triangledown_{\theta} q(\theta)}{q(\theta)} } d\theta \nonumber \\ &= \int{\triangledown_{\theta} q(\theta)} d\theta \nonumber \\ &=\triangledown_{\theta} \int{q(\theta) } d\theta = \triangledown_{\theta} E_{q(\theta)} [1]=0 \nonumber \end{align}
其中1的期望$E_q(\theta) [1]=1$,而常數的導數是0。關於這里期望$E_q(\theta)$ 和求導$\triangledown_{\theta}$換位的問題,可以根據中值定理(mean value theorem)和勒貝格控制收斂定理(dominated convergence theorem)推出:
\begin{align} \triangledown_{\theta} E_{q(\theta)} [\log{q(\theta)} ] &= \lim_{\Delta \theta \to 0}{\frac{1}{\Delta \theta} (E_q [\log{q(\theta+\Delta \theta)}] - E_q [\log{q(\theta)}])} \nonumber \\ &= \lim_{\Delta \theta \to 0}{E_q [ \frac{\log{q(\theta+\Delta \theta)}-\log{q(\theta)}}{\Delta \theta}]} \nonumber \\ &= \lim_{\Delta \theta \to 0}{E_q [\triangledown_{\theta} \log{q(\Theta(\Delta \theta))} ]} \nonumber \\ &= E_q [\lim_{\Delta \theta \to 0}{\triangledown_{\theta} \log{q(\Theta (\Delta \theta))}}] \nonumber \\ &= E_q [\triangledown_{\theta} \log{q(\theta)}] \nonumber \end{align}
其中第二行到第三行運用中值定理,第三行到第四行是控制收斂定理。
分析完了式$(\ref{1.17})$的第二項,我們再來看第三項到$(\ref{1.18})$的第三項的推導:
\begin{align} \triangledown_{\theta}^2 KL(q(\theta) \parallel p(\bar{\theta} )) &= \triangledown_{\theta}^2 E_{q(\theta)} [\log{q(\theta)}] - \triangledown_{\theta}^2 E_{q(\theta)}[\log{p(\bar{\theta})} ] \nonumber \\ &= \triangledown_{\theta}^2 E_{q(\theta)} [\log{q(\theta)}] \nonumber \\ &= E_{q(\theta)} [\triangledown_{\theta}^2 \log{q(\theta)} ] \label{1.20} \end{align}
其中期望$E_{q(\theta)}$內部為對數似然(log-likelihood)的Hessian矩陣$\triangledown_{\theta}^2 \log{q(\theta)}$,它可以進行如下變換:
\begin{align} (\log{q(\theta)})'' &= ((\log{q(\theta)})' )' \nonumber \\ &= (\frac{q' (\theta)}{q(\theta)})' \nonumber \\ &= \frac{q'' (\theta)q(\theta)- q' (\theta) q' (\theta)}{q(\theta)^2} \nonumber \\ &= \frac{q'' (\theta)}{q(\theta)} - \frac{q' (\theta)}{q(\theta)} \frac{q' (\theta)}{q(\theta)} \label{1.21} \end{align}
為了好看,這里對符號做了變換,將$\triangledown_{\theta}^2$換為$(·)''$,將$\triangledown_{\theta}$換為$(·)'$。將$(\ref{1.21})$帶入$(\ref{1.20})$,得到:
\begin{align} E_{q(\theta)} [\triangledown_{\theta}^2 \log{q(\theta)}] &= E_{q(\theta)} [\frac{q'' (\theta)}{q(\theta)} - \frac{q' (\theta)}{q(\theta)} \frac{q' (\theta)}{q(\theta)} ] \nonumber \\ &= E_{q(\theta)} [\frac{q'' (\theta)}{q(\theta)} ] - E_{q(\theta)} [(\log{q(\theta)} )' (\log{q(\theta) })' ] \nonumber \\ &= \int{q(\theta) \frac{q'' (\theta)}{q(\theta)}} d\theta - F \nonumber \\ &= \triangledown_{\theta}^2 \int{q(\theta)} d\theta - F \nonumber \\ &= -F \nonumber \end{align}
其中$F$為式$(\ref{1.16})$的Fisher信息矩陣,所以推出Fisher信息矩陣是KL散度的Hessian矩陣——Hessian矩陣是一個多元函數的二階偏導(second order partial derivative)$\triangledown_{x}^2$,形式如下:
\begin{equation} H(f) = \begin{bmatrix} \frac{\partial^2 f}{\partial x_1^2} & \frac{\partial^2 f}{\partial x_1 \partial x_2} & \dotsb & \frac{\partial^2 f}{\partial x_1 \partial x_n} \\ \frac{\partial^2 f}{\partial x_2 \partial x_1} & \frac{\partial^2 f}{\partial x_2^2} & \dotsb & \frac{\partial^2 f}{\partial x_2 \partial x_n} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial^2 f}{\partial x_n \partial x_1} & \frac{\partial^2 f}{\partial x_n \partial x_2} & \dotsb & \frac{\partial^2 f}{\partial x_n^2} \end{bmatrix} \nonumber \end{equation}
經過上面的推導,最終我們得到$(\ref{1.18})$。因為$p(\bar{\theta})$是常數,我們假定它與$q(\theta)$相同,那么$(\ref{1.18})$的等號右邊的第一項為0。再進一步簡化$(\ref{1.18})$,得到:
\begin{equation} KL(q(\theta+\Delta \theta)\parallel p(\bar{\theta})) \approx -\frac{1}{2} \Delta \theta^T F \Delta \theta \nonumber \end{equation}
觀察$(\ref{1.15})$,可以發現,此時黎曼測度張量$G(\theta)$是$-\frac{1}{2} F$。
自然梯度
標准的梯度下降法的計算采用的是式$(\ref{1.13})$,但是標准梯度下降法假設參數空間是歐式空間,而歐式空間並不適用於概率分布,所以上一節介紹了,在用KL散度來衡量分布的差異時,點$\theta$和$\theta+\Delta \theta$的距離如何表示。在上一篇我們還了解到,變分推斷采用的是$ELBO$來近似地衡量變分分布與真實分布的距離——但這里我就不對$ELBO$對應的$G(\theta)$進行推導了,感興趣的可以參考論文$[4]$——現在我們來了解如何用自然梯度(natural gradient)代替標准梯度,以及$G(\theta)$的作用。
自然梯度下降法如下所示:
\begin{align} \theta^{t+1} &= \theta^t - \eta \bar{\triangledown} L(\theta^t ) \nonumber \\ &= \theta^t - \eta G^{-1} (\theta) \triangledown L(\theta^t ) \label{1.22} \end{align}
其中$\bar{\triangledown} L(\theta^t )$是函數$L$的自然梯度,$\triangledown L(\theta^t )$是標准梯度,$\eta$是學習率,由人設定,$G^{-1}$是黎曼測度張量的逆。自然梯度$-\bar{\triangledown} L(\theta^t)$表示在黎曼空間中函數$L$的最速下降方向。對標准梯度,有:
\begin{equation} \triangledown L(\theta) = \frac{L(\theta+d \theta)-L(\theta)}{d \theta} \label{1.23} \end{equation}
其中$d\theta$表示參數$\theta$移動的距離,例如歐式距離。將$d\theta$表示為$d\theta = \varepsilon v$,其中$\varepsilon=|d\theta|$表示向量的長度,是一個很小的值——其實我們並不關心它,只關心$\theta$移動的方向——而$v$是單位向量,$|v|^2=\sum{g_{ij} v_i v_j}=v^T G(\theta)v=1$,表示$\theta$移動的方向。對$(\ref{1.23})$做一些調整:
\begin{align} &L(\theta+d\theta)= L(\theta) + \varepsilon \triangledown L(\theta)^T v \nonumber \\ &v^T G(\theta)v-1=0 \nonumber \end{align}
其中第二行為第一行的約束,約束$v$的取值范圍。
最速下降法(steepest descent method)(或最速上升)是要找到方向,使函數$L$從$\theta$移動到$\theta+d\theta$時它的值下降最快,所以對$L(\theta+d\theta)$求導找到極值點:
\begin{equation} \frac{\partial L(\theta+d\theta)}{\partial v} = \frac{\partial}{\partial v} [L(\theta)+ \varepsilon \triangledown L(\theta)^T v]=0 \nonumber \end{equation}
通過拉格朗日法把對$v$的約束加入優化:
\begin{align} &\frac{\partial}{\partial v} [L(\theta)+ \varepsilon \triangledown L(\theta)^T v+ \lambda(v^T G(\theta)v-1)]=0 \nonumber \\ &0+ \varepsilon \triangledown L(\theta)^T+2 \lambda G(\theta)v-0=0 \nonumber \\ &\triangledown L(\theta)^T+2 \lambda G(\theta)v=0 \nonumber \end{align}
其中第一行等號左邊第一項和最后一項與$v$無關,$v$是常數,求導后都為0;第二行的$\varepsilon$可以除掉,因為等號右邊為0,左邊第二項也有一個超參數$\lambda$,因此去除后無影響。結果整理,最終我們關心的最速下降(或上升)的移動方向為:
\begin{equation} v = -\frac{1}{2} \lambda G^{-1} (\theta) \triangledown L(\theta) \nonumber \end{equation}
因為$\lambda$是超參數,可以並入學習率,所以最終得到自然梯度:
\begin{equation} \bar{\triangledown} L(\theta) = G^{-1} (\theta) \triangledown L(\theta) \nonumber \end{equation}
因此得到式$(\ref{1.22})$,而且可以發現,當$G(\theta)$為單位矩陣時,自然梯度等於標准梯度。這與$(\ref{1.15})$對應,當$G(\theta)$為單位矩陣時,采用的是歐式距離。所以可以看出,標准梯度計算的也是最速下降(或上升)方向。
隨機優化
了解了求自然梯度就是求標准梯度$\triangledown L(\theta)$以及$G(\theta)$后,我們來到我們的最終目標——隨機自然梯度下降法。在隨機梯度下降法中,我們可以只取一部分數據進行計算,此時模型的目標函數是:
\begin{equation} L(x)= \frac{1}{n} \sum_{i=1}^n{L(x_i)} \nonumber \end{equation}
其中$n$為這部分數據的數據量,$x_i$和$y_i$是第$i$個數據。目標函數的梯度為:
\begin{equation} \triangledown_{\theta} L(x) = \frac{1}{n} \sum_{i=1}^n{\triangledown_{\theta} L(x_i )} \nonumber \end{equation}
這是標准梯度,計算自然梯度我們還要通過下面的式子求$G(θ)$:
\begin{equation} G(x|\theta) = \frac{1}{n} \sum_{i=1}^n{G(x_i |\theta)} \nonumber \end{equation}
總結
變分推斷這個專題總共包括三篇文章。第一篇文章介紹了變分法以及變分近似的概念,並且讓我們知道,可以通過變分推斷來解決那些難以計算的問題。第二篇文章以貝葉斯推斷為例子,分析了為什么一些問題難以准確求解,並介紹了一種構造變分推斷的方法(基於平均場定理)、變分推斷的目標函數($ELBO$)以及優化算法。最后這篇介紹了隨機變分推斷(SVI)——通過stochastic的方法來優化變分推斷模型。通過三篇文章,我們對變分推斷應該是有了一個比較全面的認識。下面兩篇文章,我們將來了解變分推斷在變分自編碼器(Variational AutoEncoder,VAE)和貝葉斯神經網絡(Bayesian Neural Network,BNN)中的應用。
完結
[1] Blei, D. M., Kucukelbir, A., McAuliffe, J. D. (2018). “variational inference a review for statisticians”.
[2] Jordan, M. I., Ghahramani, Z., Jaakkola, T., and Saul, L. (1999). “Introduction to variational methods for graphical models”.
[3] Amari, S., Douglas, S. C. (1998). “why natural gradient”.
[4] Hoffman, M. D., Blei, D., Wang, C., and Paisley, J. (2013). “stochastic variational inference”.