[學習筆記] Gibbs Sampling


Gibbs Sampling

Intro

Gibbs Sampling 方法是我最近在看概率圖模型相關的論文的時候遇見的,采樣方法大致為:迭代抽樣,最開始從隨機樣本中抽樣,然后將此樣本作為條件項,按條件概率抽樣,每次只從一個維度考慮,當所有維度均采樣完,開始下一輪迭代。

Random Sampling

假設我們一直一個隨機變量的概率密度函數,我們如何采樣得到服從這個分布的樣本呢?

學矩陣論的時候,老師教我們用反函數來生成任意概率分布的隨機數,因此,我們也可以用反函數法來生成該分布的樣本。即假設 $ \xi $ 是 $ [0,1] $ 區間上均勻分布的隨機變量,則其反函數$ cdf^{-1}( \xi ) $ 服從該概率密度函數為 $ p(x) $ 的分布。

有一個問題就是,當 $ p(x) $ 復雜到其累積分布函數的反函數無法計算的時候,或者不知道 $ p(x) $ 的精確值的時候,如何采樣呢?

這時候就要用到一些采樣的策略,比如拒絕采樣、重要性采樣、Gibbs采樣等等。下面就記一下各種采樣策略。

Rejection Sampling

拒絕采樣的原理是,已知一個提議分布q(往往是簡單分布)和原始分布p,從提議分布中采樣一個樣本\(\hat{x}\),然后計算接受率\(a(\hat{x}) = \frac{p(\hat{x}}{kq(\hat{x})}\),然后從均勻分布中生成一個值z,如果z小於等於a,則接受樣本,否則不接受樣本,繼續采樣,知道采樣到了足夠的樣本。

這個圖應該可以說明,上面藍色的線是提議分布,必須包含原始分布,然后在z0處計算接受率即可。

然而拒絕采樣要求提議分布和原始分布比較接近,這樣采樣率才會比較高,否則這個采樣方法就是低效的,所以往往實際中並不采用這種采樣方法。同樣的,重要性采樣方法也是比較低效的方法。(略去)

MCMC

MCMC是馬爾可夫蒙特卡羅方法,是一種針對高維變量的采樣方法。

MCMC的核心思想是將采樣過程看成一個馬爾可夫鏈,認為第t+1次采樣是依賴於第t次抽取樣本\(x_t\)以及狀態轉移分布\(q(x|x_t)\)。根據馬爾可夫性鏈的收斂特性,我們知道在轉移足夠多此之后最終的狀態將會收斂到一個固定的狀態,我們假定收斂時的分布為\(p(x)\),那么在狀態平穩時進行抽樣得到的樣本就肯定服從與\(p(x)\)分布。

MCMC一般應用的方法有Metropolis-Hastings算法和Gibbs采樣算法。為了快點引入Gibbs Sampling,前者略去。

Gibbs Sampling

假設有一隨機向量\(x = (x_1,x_2,...,x_d)\),其中d表示他有d維,每一維是一隨機變量,且並不是我們常見的相互獨立前提。那么,如果我們已知這個隨機向量的概率分布,我們如何從這個分布中進行采樣呢?

顯然想要從多元分布的聯合概率分布中直接抽樣是相當困難的,而Gibbs Sampling就是一種簡單而且有效的采樣方法。吉布斯采樣的大致步驟如下:

從一個隨機的初始化狀態\(x^{(0)}=[x_1|x_2^{(0)},x_3^{(0)},\cdots,x_d^{(0)}]\)開始,對每個維度單獨進行采樣,其采樣順序大致如下:

\[x_1^{(1)} \thicksim p(x_1|x_2^{(0)},x_3^{(0)},\cdots,x_d^{(0)}) \\x_2^{(1)} \thicksim p(x_2|x_1^{(0)},x_3^{(0)},\cdots,x_d^{(0)}) \\\vdots \\x_d^{(1)} \thicksim p(x_d|x_1^{(0)},x_2^{(0)},\cdots,x_{d-1}^{(0)}) \\\vdots \\x_1^{(t)} \thicksim p(x_1|x_2^{(t-1)},x_3^{(t-1)},\cdots,x_d^{(t-1)}) \\\vdots\\x_{d}^{(t)} \thicksim p(x_d|x_1^{(t-1)},x_2^{(t-1)},\cdots,x_{d-1}^{(t-1)}) \\ \]

遵從上面的采樣步驟,我們最終能夠采樣得到所需要的高維分布的樣本。需要注意的是,迭代的最開始采樣得到的樣本並不是完全滿足所需要的分布的樣本,因為采樣之初采樣的分布是提議分布,一般是均勻分布,而Gibbs Sampling的過程更像是一個單步迭代的過程,這使我想起了EM算法,都是一樣的,一步一步去迭代達到最終結果。

我在網上找到了一個能夠描述這個過程的圖片:

如上圖所示,右圖是我們需要的分布,左邊是迭代的過程,最開始抽樣的點0和1都是均勻分布抽樣得到的,而越到后面,抽樣的點都越滿足我們右邊的分布,所以這個過程可以說明Gibbs Sampling抽樣的過程是可行的。

還有下面這張圖,也差不多:

Coding

Gibbs Sampling我是從一篇圖像合成的論文中看到並有所了解的,文章基於MRF,使用神經網絡去擬合條件分布\(p(x_i|x_{-i})\),其中\(x_{-i}\)表示除了第i個屬性的其他屬性。

具體到圖像中來,\(x_i\)就是第i個位置的像素點的像素值,而\(x_{-i}\)描述的就是除了這個點以外的其他所有點,因此上式的概率分布就是一個條件分布。

使用神經網絡可以擬合出這個分布來,那么如何去生成圖片又是一個問題。

文章給出的解決方案就是Gibbs Sampling,先從隨機噪聲開始,逐像素進行生成,第一次迭代完成將生成一張圖片,那么第二次第三次依次可以使用上一次迭代完前生成的圖片進行迭代生成下一次,當迭代次數足夠多的時候,即我們認為達到了平穩分布,這個時候生成的圖片就是服從該分布的圖片了。

原文參見:

原文鏈接

具體的,我給出下面的代碼:

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils import data
from torchvision import datasets, transforms, utils
from tqdm import tqdm
from PIL import Image
import glob
import random
import cv2 as cv
class MConv(nn.Conv2d):
    '''
    mask_type A or B
    A : the center is zero
    B : the center is not zero
    '''
    def __init__(self,mask_type,*args,**kwargs):
        super(MConv,self).__init__(*args,**kwargs)
        assert mask_type in ["A","B"]
        self.mask_type = mask_type
        self.register_buffer('mask', self.weight.data.clone())
        _,_,h,w = self.weight.size()
        self.mask.fill_(1)
        self.mask[:,:,h//2,w//2 + (mask_type == 'B'):] = 0
        self.mask[:,:,h//2+1:,:] = 0
        
    def forward(self,x):
        self.weight.data *= self.mask
        return super(MaskedConv2d,self).forward(x)
    
    
class DoublePixelCNN(nn.Module):
    def __init__(self,fm,kernel_size = 7,padding = 3):
        super(DoublePixelCNN, self).__init__()
        self.net1 = nn.Sequential(
                MConv('A', 1,  64, 17, 1,8, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
                MConv('B', 64, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                #nn.Conv2d(fm, 256, 1)
        ) 
        self.net2 = nn.Sequential(
                MConv('A', 1,  64, 17, 1,8, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
                MConv('B', 64, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                #nn.Conv2d(fm, 256, 1)
        ) 
        
        self.conv1x1 = nn.Conv2d(fm*2, 256, 1)
    def forward(self,x):
        x1 = self.net1(x)
        x2 = self.net2(x.flip(dims = [-1,-2]))
        x = torch.cat([x1,x2.flip(dims = [-1,-2])],dim = 1)
        x = self.conv1x1(x)
        return x

if __name__ == "__main__":
	tr =       data.DataLoader(datasets.MNIST(root="/media/xueaoru/Ubuntu/dataset/data",transform=transforms.ToTensor(),),
                     batch_size=64, shuffle=True, num_workers=12, pin_memory=True)
    net = DoublePixelCNN(128)
    net.cuda()
    sample = torch.rand(64,1,k,k).cuda()
    optimizer = optim.Adam(net.parameters(),lr = 0.0001)
    for epoch in range(1000):
        net.train()
        running_loss = 0.
        for input,_ in tqdm(tr):
            #print(input.size())
            input = input.cuda()
            #target = target.cuda()
            target = (input.data[:,:] * 255).long() # (b,3,h,w)
            # net(input) (b,256,3,h,w)
            loss = F.cross_entropy(net(input), target) # 計算的是每個像素的二分類交叉熵
            running_loss += loss.item()
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print("training loss: {:.8f}".format(running_loss / len(tr)))
        if epoch % 5 == 0:
            torch.save(net.state_dict(),open("./{}.pth".format(epoch),"wb"))
            #sample.fill_(0)
            net.eval()
            with torch.no_grad():
                for t in tqdm(range(300)):
                    for i in range(k):
                        for j in range(k):
                            out = net(sample) # (b,256)
                            probs = F.softmax(out[:, :, i ,j],dim = 1).data # (b,c) = (16,256)
                            sample[:, :, i, j] = torch.multinomial(probs, 1).float() / 255.
                
                utils.save_image(sample, 'sample_{:02d}.png'.format(epoch), nrow=12, padding=0)
    			sample = torch.rand(64,1,k,k).cuda()

由於這個方法采樣時間極其緩慢,所以我生成的圖片尺度比較小,訓練周期也比較短,只是做個demo使用。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM