介紹摘自李航《統計學習方法》
EM算法
EM算法是一種迭代算法,1977年由Dempster等人總結提出,用於含有隱變量(hidden variable)的概率模型參數的極大似然估計,或極大后驗概率估計。EM算法的每次迭代由兩步組成:E步,求期望(expectation);M步,求極大(maximization)。所以這一算法稱為期望極大算法(expectation maximization algorithm),簡稱EM算法。本章首先敘述EM算法,然后討論EM算法的收斂性;作為EM算法的應用,介紹高斯混合模型的學習;最后敘述EM算法的推廣——GEM算法。
將觀測數據表示為Y=(Y1,Y2,…,Yn)T,未觀測數據表示為Z=(Z1,Z2,…,Zn)T,則觀測數據的似然函數為
即
考慮求模型參數=(
Π,p,q)的極大似然估計,即
這個問題沒有解析解,只有通過迭代的方法求解。EM算法就是可以用於求解這個問題的一種迭代算法。下面給出針對以上問題的EM算法,其推導過程省略。
EM算法首先選取參數的初值,記作(0)=(
Π(0),p(0),q(0)),然后通過下面的步驟迭代計算參數的估計值,直至收斂為止。第i次迭代參數的估計值為
(i)=(
(i),p(i),q(i))。EM算法的第i+1次迭代如下。
E步:計算在模型參數Π(i),p(i),q(i)下觀測數據yj來自擲硬幣B的概率
M步:計算模型參數的新估計值
EM算法在高斯混合模型學習中的應用
EM算法的一個重要應用是高斯混合模型的參數估計。高斯混合模型應用廣泛,在許多情況下,EM算法是學習高斯混合模型(Gaussian misture model)的有效方法。
定義9.2(高斯混合模型) 高斯混合模型是指具有如下形式的概率分布模型:
高斯混合模型參數估計的EM算法
1.明確隱變量,寫出完全數據的對數似然函數
可以設想觀測數據yj,j=1,2,…,N,是這樣產生的:首先依概率ak選擇第k個高斯分布分模型Ø(y|θk);然后依第k個分模型的概率分布Ø(y|
θk)生成觀測數據yj。這時觀測數據yj,j=1,2,…,N,是已知的;反映觀測數據yj來自第k個分模型的數據是未知的,k=1,2,…,K,以隱變量
γjk表示,其定義如下:
有了觀測數據yj及未觀測數據γjk,那么完全數據是
於是,可以寫出完全數據的似然函數:
2.EM算法的E步:確定Q函數
3.確定EM算法的M步
迭代的M步是求函數Q(θ,θ
(i))對
的極大值,即求新一輪迭代的模型參數:
重復以上計算,直到對數似然函數值不再有明顯的變化為止。
1 # coding:utf-8 2 import numpy as np 3 4 def qq(y,alpha,mu,sigma,K,gama):#計算Q函數 5 gsum=[] 6 n=len(y) 7 for k in range(K): 8 gsum.append(np.sum([gama[j,k] for j in range(n)])) 9 return np.sum([g*np.log(ak) for g,ak in zip(gsum,alpha)])+\ 10 np.sum([[np.sum(gama[j,k]*(np.log(1/np.sqrt(2*np.pi))-np.log(np.sqrt(sigma[k]))-1/2/sigma[k]*(y[j]-mu[k])**2)) 11 for j in range(n)] for k in range(K)]) #《統計學習方法》中公式9.29有誤 12 13 def phi(mu,sigma,y): #計算phi 14 return 1/(np.sqrt(2*np.pi*sigma)*np.exp(-(y-mu)**2/2/sigma)) 15 16 def gama(alpha,mu,sigma,i,k): #計算gama 17 sumak=np.sum([[a*phi(m,s,i)] for a,m,s in zip(alpha,mu,sigma)]) 18 return alpha[k]*phi(mu[k],sigma[k],i)/sumak 19 20 def dataN(length,k):#生成數據 21 y=[np.random.normal(5*j,j+5,length/k) for j in range(k)] 22 return y 23 24 def EM(y,K,iter=1000): #EM算法 25 n = len(y) 26 sigma=[10]*K 27 mu=range(K) 28 alpha=np.ones(K) 29 qqold,qqnew=0,0 30 for it in range(iter): 31 gama2=np.ones((n,K)) 32 for k in range(K): 33 for i in range(n): 34 gama2[i,k]=gama(alpha,mu,sigma,y[i],k) 35 for k in range(K): 36 sum_gama=np.sum([gama2[j,k] for j in range(n)]) 37 mu[k]=np.sum([gama2[j,k]*y[j] for j in range(n)])/sum_gama 38 sigma[k]=np.sum([gama2[j,k]*(y[j]-mu[k])**2 for j in range(n)])/sum_gama 39 alpha[k]=sum_gama/n 40 qqnew=qq(y,alpha,mu,sigma,K,gama2) 41 if abs(qqold-qqnew)<0.000001: 42 break 43 qqold=qqnew 44 return alpha,mu,sigma 45 46 N = 500 47 k=2 48 data=dataN(N,k) 49 y=np.reshape(data,(1,N)) 50 a,b,c = EM(y[0], k) 51 print a,b,c 52 # iter=180 53 #[ 0.57217609 0.42782391] [4.1472879054766887, 0.72534713118155769] [44.114682884921415, 24.676116557533351] 54 55 sigma = 6 #網上的數據 56 miu1 = 40 57 miu2 = 20 58 X = np.zeros((1, N)) 59 for i in xrange(N): 60 if np.random.random() > 0.5: 61 X[0, i] = np.random.randn() * sigma + miu1 62 else: 63 X[0, i] = np.random.randn() * sigma + miu2 64 a,b,c = EM(X[0], k) 65 print a,b,c 66 # iter=114 67 #[ 0.44935959 0.55064041] [40.561782615819361, 21.444533254494189] [33.374144230703514, 51.459622219329155]