概
當前生成模型, 要么依賴對抗損失(GAN), 要么依賴替代損失(VAE), 本文提出了基於score matching 訓練, 以及利用annealed Langevin dynamics推斷的模型, 思想非常有趣.
主要內容
Langevin dynamics
對於分布\(p(x)\), 我們可以通過下列方式迭代生成
其中\(\tilde{x}_0 \sim \pi(x)\)來自一個先驗分布, \(z_t \sim \mathcal{N}(0, I)\). 當步長\(\epsilon \rightarrow 0\)並且\(T \rightarrow +\infty\)的時候, \(\tilde{x}_T\)可以認為是從\(p(x)\)中采樣的樣本.
注: 一般的Langevin, dynamics還需要在每一次迭代后計算一個接受概率然后判斷是否接受, 不過在實際中這一步往往可以省略.
Score Matching
通過上述的迭代可以發現, 我們只需要獲得\(\nabla_x \log p(x)\)即可采樣\(x\), 我們可以期望通過下面的方式, 通過一個網絡\(s_{\theta}(x)\)來逼近\(\nabla_x \log p_{data}(x)\):
但是在實際中, 先驗\(\log p_{data}(x)\)也是未知的, 幸運的是上述公式等價於:
注: 見 score matching
Denoising Score Matching
一個共識是, 所獲得的數據往往是一個低維流形, 即其內在的維度實際上很低. 所以\(\mathbb{E}_{p_{data}(x)}\)在實際中會出現高密度的區域估計得很好, 但是低密度得區域估計得非常差. Denosing Score Matching提高了一個較為魯棒的替代方法:
當優化得足夠好的時候,
實際中, 通常取\(q_{\sigma}(\tilde{x}|x) = \mathcal{N}(\tilde{x}|x, \sigma^2 I)\), 相當於在真實數據\(x\)上加了一個擾動, 當擾動足夠小(\(\sigma\)足夠小)的時候, \(q_{\sigma}(x) \approx p_{data}(x)\), 則\(s_{\theta^*}(x) \approx \nabla_x \log p_{data}(x)\).
注: 為啥期望部分要有\(p_{data}\)? 實際上上述目標和score matching依舊是等價的.
Noise Conditional Score Networks
Slow mixing of Langevin dynamics
假設\(p_{data}(x) = \pi p_1(x) + (1 - \pi)p_2(x)\), 且\(p_1, p_2\)的支撐集合是互斥的, 那么 \(\nabla_{x} \log p_{data}(x)\)要么為\(\nabla_{x} \log p_{1}(x)\)或者\(\nabla_{x} \log p_{2}(x)\), 與\(\pi\)沒有絲毫關聯, 這會導致訓練的結果與\(\pi\)也沒有關聯. 在實際中, 若\(p_1, p_2\)近似互斥, 也會產生類似的情況:
如上圖所示, 通過Langevin dynamics采樣的點幾乎是1:1的, 這與真實的分布便有了出入.
作者的想法是, 設計一個noise conditional score networks:
給定不同的\(\sigma\)其擬合不同擾動大小的\(p_{\sigma}\), 在采樣中, 首先用大一點的\(\sigma\), 然后再逐步縮小, 這便是一種退火的思想. 顯然, 一開始用大一點的\(\sigma\)能夠為后面的采樣提供更好更魯棒的初始點.
損失函數
設定\(\sigma_i, i=1,2,\cdots, L\), 且滿足:
即一個等比例(縮小)的數列.
對於每個\(\sigma\)采用如下損失:
注: \(\nabla_{\tilde{x}} q_{\sigma}(\tilde{x}|x) = -\frac{\tilde{x} - x}{\sigma^2}\).
於是總損失為
\(\lambda(\sigma_i)\)為權重系數.
Annealed Langevin dynamics
Input: \(\{\sigma_i\}_{i=1}^L, \epsilon, T\);
- 初始化\(x_0\);
- For \(i=1,2,\cdots, L\) do:
- \(\alpha_i \leftarrow \epsilon \cdot \sigma_i^2 / \sigma_L^2\);
- For \(t=1,2,\cdots, T\) do:
- 采樣\(z_t \sim \mathcal{N}(0, I)\);
- \(x_t \leftarrow x_{t-1} + \frac{\alpha_i}{2}s_{\theta}(x_{t-1}, \sigma) + \sqrt{\alpha_i} z_t\);
- \(x_0 \leftarrow x_T\);
Output: \(x_T\).
細節
-
關於參數\(\lambda(\sigma)\)的選擇:
作者推薦選擇\(\lambda(\sigma) = \sigma^2\), 因為當優化到最優的時候, \(\|s_{\theta}(x, \sigma)\|_2 \propto 1 / \sigma\), 故\(\sigma^2 \ell(\theta;\sigma) = \frac{1}{2}\mathbb{E}[\|\sigma s_{\theta}(x, \sigma) + \frac{\tilde{x} - x}{\sigma} \|_2^2]\), 其中\(\sigma s_{\theta}(x, \sigma) \propto 1, \frac{\tilde{x} - x}{\sigma} \sim \mathcal{N}(0, I)\), 故\(\sigma^2 \ell_{\theta,\sigma}\)與\(\sigma\)無關. -
關於\(\alpha_i \leftarrow \epsilon \cdot \sigma_i^2 / \sigma_L^2\):
對於一次Langevin dynamic, 其獲得的信息為: \(\frac{\alpha_i}{2} s_{\theta}(x_{t-1}, \sigma)\), 其噪聲為\(\sqrt{\alpha_i}z_t\), 故其信噪比(signal-to-noise)為(應該是element-wise的計算?)
當我們按照算法中的取法時, 我們有
故采用此策略能夠保證SNR保持一個穩定的值.