MRF馬爾可夫隨機場入門
Intro
MRF是一種廣泛應用於圖像分割的模型,當然我看到MRF的時候並不是因為分割,而是在圖像生成領域,有的paper利用MRF模型來生成圖像,因此入門一下MRF,並以分割模型為例記一下代碼。
Model
Target
在圖像分割中,我們的任務是給定一張圖像,輸出每個像素的標簽。因此我們就是要得到在給定圖片特征之下,標簽概率最大化時所對應的標簽。
因此可以這么建模:
其中w表示標簽,f表示圖像特征,求最大后驗概率。
根據貝葉斯理論,上式右邊可以寫成:
其中,P(f)是常量,因為當一張圖片確定之后,P(f)便確定了。因此,上式只取決於分子部分。分子又可以表達為\(P(f,\omega)\),所以我們直接建模的其實是這個部分,計算的也是這個部分,這是與CRF不同的一點(MRF是直接對左邊建模,不分解為右邊,所以沒個樣本都要算一遍后驗概率,然后乘起來最大化,MRF其實是通過對等式右邊分子建模"曲線救國")。
因此,我們的任務中只需要對分子的兩個部分進行定義即可。
Neighbors
像素Neighbors的定義很簡單,就是這個像素周圍的其他像素。
舉例而言,下圖分別是中心點像素的四鄰域和八鄰域。
Hammersley-Clifford Theorem
定理的內容為:
如果一個分布\(P(x)>0\)滿足無向圖\(G\)中的局部馬爾可夫性質,當且僅當\(P(x)\)可以表示為一系列定義在最大團上的非負函數的乘積形式,即:
其中\(C\)為\(G\)中最大團集合,也就是所有的最大團組成的集合,\(\phi(x_c) \ge 0\)是定義在團\(c\)上的勢能函數,Z是配分函數,用來將乘積歸一化為概率的形式。
無向圖模型與有向圖模型的一個重要區別就是配分函數Z。
Hammersley-Clifford Theorem表明,無向圖模型和吉布斯分布是一致的,所以將\(P(\omega)\)定義下式:
其中,Z作為normalization項,\(Z = \sum exp(-U(\omega))\),U定義為勢能,而等號最右邊將U變成了V的求和,在后面我們會說到,這里其實是每個原子團的勢能的求和。
Clique
Clique就是我們上面提到的“團”的概念。集合\(c\)是\(S\)的原子團當且僅當c中的每個元素都與該集合中的其他元素相鄰。那么Clique就是所有\(c\)的並集。
舉例而言:
一個像素的四鄰域及他自己組成的集合的原子團可以分為singleton和doubleton如圖所示。
Clique Potential
翻譯過來就是勢能,用\(V(w)\)表示,描述的是一個Clique的能量。
那么,一個像素的領域的勢能就是每個團的能量的和。
其中c表示原子團,c表示Clique,V是如何定義的呢?
在圖像分割中,可以以一階團為例,
到這里,\(P(\omega)\)的所有變量解釋完了,下一步是計算\(P(f|\omega)\)
\(P(f|\omega)\)的計算
\(P(f|\omega)\)被認為是服從高斯分布的,也就是說,如果我們知道了這個像素的標簽是什么,那么他的像素值應該服從這個標簽下的條件概率的高斯分布。其實他服從高斯分布還是很好理解的,我們已知這個像素點的label比如說是A,那么我們去統計一下所有標簽是A的點的像素值的均值和方差,顯然以這個均值和方差為參數的高斯分布更加契合這里的條件分布。
計算每個類別的像素均值和方差,帶入公式,即得條件概率。
最后,就是最大化\(P(\omega)P(f|\omega)\),以對數形式轉化為求和的形式去優化,最大化\(log(P(\omega)) + log(P(f|\omega))\).
Coding
import numpy as np
import cv2 as cv
import copy
class MRF():
def __init__(self,img,max_iter = 100,num_clusters = 5,init_func = None,beta = 8e-4):
self.max_iter = max_iter
self.kernels = np.zeros(shape = (8,3,3))
self.beta = beta
self.num_clusters = num_clusters
for i in range(9):
if i < 4:
self.kernels[i,i//3,i%3] = 1
elif i > 4:
self.kernels[i-1,i//3,i%3] = 1
self.img = img
if init_func is None:
self.labels = np.random.randint(low = 1,high = num_clusters + 1,size = img.shape,dtype = np.uint8)
def __call__(self):
img = self.img.reshape((-1,))
for iter in range(self.max_iter):
p1 = np.zeros(shape = (self.num_clusters,self.img.shape[0] * self.img.shape[1]))
for cluster_idx in range(self.num_clusters):
temp = np.zeros(shape = (self.img.shape))
for i in range(8):
res = cv.filter2D(self.labels,-1,self.kernels[i,:,:])
temp[(res == (cluster_idx + 1))] -= self.beta
temp[(res != (cluster_idx + 1))] += self.beta
temp = np.exp(-temp)
p1[cluster_idx,:] = temp.reshape((-1,))
p1 = p1 / np.sum(p1)
p1[p1 == 0] = 1e-3
mu = np.zeros(shape = (self.num_clusters,))
sigma = np.zeros(shape = (self.num_clusters,))
for i in range(self.num_clusters):
#mu[i] = np.mean(self.img[self.labels == (i+1)])
data = self.img[self.labels == (i+1)]
if np.sum(data) > 0:
mu[i] = np.mean(data)
sigma[i] = np.var(data)
else:
mu[i]= 0
sigma[i] = 1
#print(sigma[i])
#sigma[sigma == 0] = 1e-3
p2 = np.zeros(shape = (self.num_clusters,self.img.shape[0] * self.img.shape[1]))
for i in range(self.img.shape[0] * self.img.shape[1]):
for j in range(self.num_clusters):
#print(sigma[j])
p2[j,i] = -np.log(np.sqrt(2*np.pi)*sigma[j]) -(img[i]-mu[j])**2/2/sigma[j];
self.labels = np.argmax(np.log(p1) + p2,axis = 0) + 1
self.labels = np.reshape(self.labels,self.img.shape).astype(np.uint8)
print("-----------start-----------")
print(p1)
print("-" * 20)
print(p2)
print("----------end------------")
#print("iter {} over!".format(iter))
#self.show()
#print(self.labels)
def show(self):
h,w = self.img.shape
show_img = np.zeros(shape = (h,w,3),dtype = np.uint8)
show_img[self.labels == 1,:] = (0,255,255)
show_img[self.labels == 2,:] = (220,20,60)
show_img[self.labels == 3,:] = (65,105,225)
show_img[self.labels == 4,:] = (50,205,50)
#img = self.labels / (self.num_clusters) * 255
cv.imshow("res",show_img)
cv.waitKey(0)
if __name__ == "__main__":
img = cv.imread("/home/xueaoru/圖片/0.jpg")
img = cv.cvtColor(img,cv.COLOR_BGR2GRAY)
img = img/255.
#img = np.random.rand(64,64)
#img = cv.resize(img,(256,256))
mrf = MRF(img = img,max_iter = 20,num_clusters = 2)
mrf()
mrf.show()
#print(mrf.kernels)
Input:
Output(num_clusters = 4):
Output(num_clusters = 2):