高斯混合模型GMM與EM算法的Python實現


GMM與EM算法的Python實現

高斯混合模型(GMM)是一種常用的聚類模型,通常我們利用最大期望算法(EM)對高斯混合模型中的參數進行估計。


1. 高斯混合模型(Gaussian Mixture models, GMM)

高斯混合模型(Gaussian Mixture Model,GMM)是一種軟聚類模型。 GMM也可以看作是K-means的推廣,因為GMM不僅是考慮到了數據分布的均值,也考慮到了協方差。和K-means一樣,我們需要提前確定簇的個數。

GMM的基本假設為數據是由幾個不同的高斯分布的隨機變量組合而成。如下圖,我們就是用三個二維高斯分布生成的數據集。

png

在高斯混合模型中,我們需要估計每一個高斯分布的均值與方差。從最大似然估計的角度來說,給定某個有n個樣本的數據集X,假如已知GMM中一共有簇,我們就是要找到k組均值μ1,,μkk組方差σ1,,σk 來最大化以下似然函數L

這里直接計算似然函數比較困難,於是我們引入隱變量(latent variable),這里的隱變量就是每個樣本屬於每一簇的概率。假設W是一個n×k的矩陣,其中 Wi,j 是第 i 個樣本屬於第 j 簇的概率。

在已知W的情況下,我們就很容易計算似然函數LW

將其寫成

其中P(Xi μjσj是樣本Xi在第j個高斯分布中的概率密度函數。

以一維高斯分布為例,

2. 最大期望算法(Expectation–Maximization, EM)

有了隱變量還不夠,我們還需要一個算法來找到最佳的W,從而得到GMM的模型參數。EM算法就是這樣一個算法。

簡單說來,EM算法分兩個步驟。

  • 第一個步驟是E(期望),用來更新隱變量WW;
  • 第二個步驟是M(最大化),用來更新GMM中各高斯分布的參量μjσj

然后重復進行以上兩個步驟,直到達到迭代終止條件。

3. 具體步驟以及Python實現

完整代碼在第4節。

首先,我們先引用一些我們需要用到的庫和函數。

1 import numpy as np 2 import matplotlib.pyplot as plt 3 from matplotlib.patches import Ellipse 4 from scipy.stats import multivariate_normal 5 plt.style.use('seaborn'

接下來,我們生成2000條二維模擬數據,其中400個樣本來自N(μ1,var1),600個來自N(μ2,var2),1000個樣本來自N(μ3,var3)

 1 # 第一簇的數據
 2 num1, mu1, var1 = 400, [0.5, 0.5], [1, 3]  3 X1 = np.random.multivariate_normal(mu1, np.diag(var1), num1)  4 # 第二簇的數據
 5 num2, mu2, var2 = 600, [5.5, 2.5], [2, 2]  6 X2 = np.random.multivariate_normal(mu2, np.diag(var2), num2)  7 # 第三簇的數據
 8 num3, mu3, var3 = 1000, [1, 7], [6, 2]  9 X3 = np.random.multivariate_normal(mu3, np.diag(var3), num3) 10 # 合並在一起
11 X = np.vstack((X1, X2, X3))

 

數據如下圖所示:

1 plt.figure(figsize=(10, 8)) 2 plt.axis([-10, 15, -5, 15]) 3 plt.scatter(X1[:, 0], X1[:, 1], s=5) 4 plt.scatter(X2[:, 0], X2[:, 1], s=5) 5 plt.scatter(X3[:, 0], X3[:, 1], s=5) 6 plt.show()

 

png

3.1 變量初始化

首先要對GMM模型參數以及隱變量進行初始化。通常可以用一些固定的值或者隨機值。

n_clusters 是GMM模型中聚類的個數,和K-Means一樣我們需要提前確定。這里通過觀察可以看出是3。(拓展閱讀:如何確定GMM中聚類的個數?

n_points 是樣本點的個數。

Mu 是每個高斯分布的均值。

Var 是每個高斯分布的方差,為了過程簡便,我們這里假設協方差矩陣都是對角陣。

W 是上面提到的隱變量,也就是每個樣本屬於每一簇的概率,在初始時,我們可以認為每個樣本屬於某一簇的概率都是1/3

Pi 是每一簇的比重,可以根據W求得,在初始時,Pi = [1/3, 1/3, 1/3]

1 n_clusters = 3
2 n_points = len(X) 3 Mu = [[0, -1], [6, 0], [0, 9]] 4 Var = [[1, 1], [1, 1], [1, 1]] 5 Pi = [1 / n_clusters] * 3
6 W = np.ones((n_points, n_clusters)) / n_clusters 7 Pi = W.sum(axis=0) / W.sum()

 

3.2 E步驟

E步驟中,我們的主要目的是更新W。第i個變量屬於第m簇的概率為:

 

根據W,我們就可以更新每一簇的占比πm

 1 def update_W(X, Mu, Var, Pi):  2     n_points, n_clusters = len(X), len(Pi)  3     pdfs = np.zeros(((n_points, n_clusters)))  4     for i in range(n_clusters):  5         pdfs[:, i] = Pi[i] * multivariate_normal.pdf(X, Mu[i], np.diag(Var[i]))  6     W = pdfs / pdfs.sum(axis=1).reshape(-1, 1)  7     return W  8 
 9 
10 def update_Pi(W): 11     Pi = W.sum(axis=0) / W.sum() 12     return Pi

 

以下是計算對數似然函數的logLH以及用來可視化數據的plot_clusters

 1 def logLH(X, Pi, Mu, Var):  2     n_points, n_clusters = len(X), len(Pi)  3     pdfs = np.zeros(((n_points, n_clusters)))  4     for i in range(n_clusters):  5         pdfs[:, i] = Pi[i] * multivariate_normal.pdf(X, Mu[i], np.diag(Var[i]))  6     return np.mean(np.log(pdfs.sum(axis=1)))  7 
 8 
 9 def plot_clusters(X, Mu, Var, Mu_true=None, Var_true=None): 10     colors = ['b', 'g', 'r'] 11     n_clusters = len(Mu) 12     plt.figure(figsize=(10, 8)) 13     plt.axis([-10, 15, -5, 15]) 14     plt.scatter(X[:, 0], X[:, 1], s=5) 15     ax = plt.gca() 16     for i in range(n_clusters): 17         plot_args = {'fc': 'None', 'lw': 2, 'edgecolor': colors[i], 'ls': ':'} 18         ellipse = Ellipse(Mu[i], 3 * Var[i][0], 3 * Var[i][1], **plot_args) 19  ax.add_patch(ellipse) 20     if (Mu_true is not None) & (Var_true is not None): 21         for i in range(n_clusters): 22             plot_args = {'fc': 'None', 'lw': 2, 'edgecolor': colors[i], 'alpha': 0.5} 23             ellipse = Ellipse(Mu_true[i], 3 * Var_true[i][0], 3 * Var_true[i][1], **plot_args) 24  ax.add_patch(ellipse) 25     plt.show()

 

3.2 M步驟

M步驟中,我們需要根據上面一步得到的W來更新均值Mu和方差Var。 MuVar是以W的權重的樣本X的均值和方差。

因為這里的數據是二維的,第m簇的第k個分量的均值,

 

m簇的第k個分量的方差,

 

 

以上迭代公式寫成如下函數update_Muupdate_Var

 1 def update_Mu(X, W):  2     n_clusters = W.shape[1]  3     Mu = np.zeros((n_clusters, 2))  4     for i in range(n_clusters):  5         Mu[i] = np.average(X, axis=0, weights=W[:, i])  6     return Mu  7 
 8 def update_Var(X, Mu, W):  9     n_clusters = W.shape[1] 10     Var = np.zeros((n_clusters, 2)) 11     for i in range(n_clusters): 12         Var[i] = np.average((X - Mu[i]) ** 2, axis=0, weights=W[:, i]) 13     return Var

 

3.3 迭代求解

下面我們進行迭代求解。

圖中實現是真實的高斯分布,虛線是我們估計出的高斯分布。可以看出,經過5次迭代之后,兩者幾乎完全重合。

1 loglh = [] 2 for i in range(5): 3  plot_clusters(X, Mu, Var, [mu1, mu2, mu3], [var1, var2, var3]) 4  loglh.append(logLH(X, Pi, Mu, Var)) 5     W = update_W(X, Mu, Var, Pi) 6     Pi = update_Pi(W) 7     Mu = update_Mu(X, W) 8     print('log-likehood:%.3f'%loglh[-1]) 9     Var = update_Var(X, Mu, W)

 

png

1 log-likehood:-8.054

 

png

1 log-likehood:-4.731

 

png

1 log-likehood:-4.729

 

png

1 log-likehood:-4.728

 

png

1 log-likehood:-4.728

 

4. 完整代碼

 1 import numpy as np  2 import matplotlib.pyplot as plt  3 from matplotlib.patches import Ellipse  4 from scipy.stats import multivariate_normal  5 plt.style.use('seaborn')  6 
 7 # 生成數據
 8 def generate_X(true_Mu, true_Var):  9     # 第一簇的數據
10     num1, mu1, var1 = 400, true_Mu[0], true_Var[0] 11     X1 = np.random.multivariate_normal(mu1, np.diag(var1), num1) 12     # 第二簇的數據
13     num2, mu2, var2 = 600, true_Mu[1], true_Var[1] 14     X2 = np.random.multivariate_normal(mu2, np.diag(var2), num2) 15     # 第三簇的數據
16     num3, mu3, var3 = 1000, true_Mu[2], true_Var[2] 17     X3 = np.random.multivariate_normal(mu3, np.diag(var3), num3) 18     # 合並在一起
19     X = np.vstack((X1, X2, X3)) 20     # 顯示數據
21     plt.figure(figsize=(10, 8)) 22     plt.axis([-10, 15, -5, 15]) 23     plt.scatter(X1[:, 0], X1[:, 1], s=5) 24     plt.scatter(X2[:, 0], X2[:, 1], s=5) 25     plt.scatter(X3[:, 0], X3[:, 1], s=5) 26  plt.show() 27     return X 28 
29 
30 # 更新W
31 def update_W(X, Mu, Var, Pi): 32     n_points, n_clusters = len(X), len(Pi) 33     pdfs = np.zeros(((n_points, n_clusters))) 34     for i in range(n_clusters): 35         pdfs[:, i] = Pi[i] * multivariate_normal.pdf(X, Mu[i], np.diag(Var[i])) 36     W = pdfs / pdfs.sum(axis=1).reshape(-1, 1) 37     return W 38 
39 
40 # 更新pi
41 def update_Pi(W): 42     Pi = W.sum(axis=0) / W.sum() 43     return Pi 44 
45 
46 # 計算log似然函數
47 def logLH(X, Pi, Mu, Var): 48     n_points, n_clusters = len(X), len(Pi) 49     pdfs = np.zeros(((n_points, n_clusters))) 50     for i in range(n_clusters): 51         pdfs[:, i] = Pi[i] * multivariate_normal.pdf(X, Mu[i], np.diag(Var[i])) 52     return np.mean(np.log(pdfs.sum(axis=1))) 53 
54 
55 # 畫出聚類圖像
56 def plot_clusters(X, Mu, Var, Mu_true=None, Var_true=None): 57     colors = ['b','g','r'] 58     n_clusters = len(Mu) 59     plt.figure(figsize=(10,8)) 60     plt.axis([-10,15,-5,15]) 61     plt.scatter(X[:,0], X[:,1], s=5) 62     ax = plt.gca()for i in range(n_clusters): 63         plot_args ={'fc':'None','lw':2,'edgecolor': colors[i],'ls':':'} 64         ellipse =Ellipse(Mu[i],3*Var[i][0],3*Var[i][1],**plot_args) 65         ax.add_patch(ellipse)if(Mu_trueisnotNone)&(Var_trueisnotNone):for i in range(n_clusters): 66             plot_args ={'fc':'None','lw':2,'edgecolor': colors[i],'alpha':0.5} 67             ellipse =Ellipse(Mu_true[i],3*Var_true[i][0],3*Var_true[i][1],**plot_args) 68  ax.add_patch(ellipse) 69     plt.show()# 更新Mudef update_Mu(X, W):
70     n_clusters = W.shape[1]Mu= np.zeros((n_clusters,2))for i in range(n_clusters):Mu[i]= np.average(X, axis=0, weights=W[:, i])returnMu# 更新Vardef update_Var(X,Mu, W):
71     n_clusters = W.shape[1]Var= np.zeros((n_clusters,2))for i in range(n_clusters):Var[i]= np.average((X -Mu[i])**2, axis=0, weights=W[:, i])returnVarif __name__ =='__main__':# 生成數據
72     true_Mu =[[0.5,0.5],[5.5,2.5],[1,7]] 73     true_Var =[[1,3],[2,2],[6,2]] 74     X = generate_X(true_Mu, true_Var)# 初始化
75     n_clusters =3
76     n_points = len(X)Mu=[[0,-1],[6,0],[0,9]]Var=[[1,1],[1,1],[1,1]]Pi=[1/ n_clusters]*3
77     W = np.ones((n_points, n_clusters))/ n_clusters 78     Pi= W.sum(axis=0)/ W.sum()# 迭代
79     loglh =[]for i in range(5): 80  plot_clusters(X,Mu,Var, true_Mu, true_Var) 81  loglh.append(logLH(X,Pi,Mu,Var)) 82         W = update_W(X,Mu,Var,Pi)Pi= update_Pi(W)Mu= update_Mu(X, W)print('log-likehood:%.3f'%loglh[-1])Var= update_Var(X,Mu, W)

 

本教程基於Python 3.6

原創者:u_u | 修改校對:SofaSofa TeamM | 轉自: http://sofasofa.io/tutorials/gmm_em/

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM