GMM算法的matlab程序
在“GMM算法的matlab程序(初步)”這篇文章中已經用matlab程序對iris數據庫進行簡單的實現,下面的程序最終的目的是求准確度。
作者:凱魯嘎吉 - 博客園 http://www.cnblogs.com/kailugaji/
1.采用iris數據庫
iris_data.txt
5.1 3.5 1.4 0.2 4.9 3 1.4 0.2 4.7 3.2 1.3 0.2 4.6 3.1 1.5 0.2 5 3.6 1.4 0.2 5.4 3.9 1.7 0.4 4.6 3.4 1.4 0.3 5 3.4 1.5 0.2 4.4 2.9 1.4 0.2 4.9 3.1 1.5 0.1 5.4 3.7 1.5 0.2 4.8 3.4 1.6 0.2 4.8 3 1.4 0.1 4.3 3 1.1 0.1 5.8 4 1.2 0.2 5.7 4.4 1.5 0.4 5.4 3.9 1.3 0.4 5.1 3.5 1.4 0.3 5.7 3.8 1.7 0.3 5.1 3.8 1.5 0.3 5.4 3.4 1.7 0.2 5.1 3.7 1.5 0.4 4.6 3.6 1 0.2 5.1 3.3 1.7 0.5 4.8 3.4 1.9 0.2 5 3 1.6 0.2 5 3.4 1.6 0.4 5.2 3.5 1.5 0.2 5.2 3.4 1.4 0.2 4.7 3.2 1.6 0.2 4.8 3.1 1.6 0.2 5.4 3.4 1.5 0.4 5.2 4.1 1.5 0.1 5.5 4.2 1.4 0.2 4.9 3.1 1.5 0.2 5 3.2 1.2 0.2 5.5 3.5 1.3 0.2 4.9 3.6 1.4 0.1 4.4 3 1.3 0.2 5.1 3.4 1.5 0.2 5 3.5 1.3 0.3 4.5 2.3 1.3 0.3 4.4 3.2 1.3 0.2 5 3.5 1.6 0.6 5.1 3.8 1.9 0.4 4.8 3 1.4 0.3 5.1 3.8 1.6 0.2 4.6 3.2 1.4 0.2 5.3 3.7 1.5 0.2 5 3.3 1.4 0.2 7 3.2 4.7 1.4 6.4 3.2 4.5 1.5 6.9 3.1 4.9 1.5 5.5 2.3 4 1.3 6.5 2.8 4.6 1.5 5.7 2.8 4.5 1.3 6.3 3.3 4.7 1.6 4.9 2.4 3.3 1 6.6 2.9 4.6 1.3 5.2 2.7 3.9 1.4 5 2 3.5 1 5.9 3 4.2 1.5 6 2.2 4 1 6.1 2.9 4.7 1.4 5.6 2.9 3.6 1.3 6.7 3.1 4.4 1.4 5.6 3 4.5 1.5 5.8 2.7 4.1 1 6.2 2.2 4.5 1.5 5.6 2.5 3.9 1.1 5.9 3.2 4.8 1.8 6.1 2.8 4 1.3 6.3 2.5 4.9 1.5 6.1 2.8 4.7 1.2 6.4 2.9 4.3 1.3 6.6 3 4.4 1.4 6.8 2.8 4.8 1.4 6.7 3 5 1.7 6 2.9 4.5 1.5 5.7 2.6 3.5 1 5.5 2.4 3.8 1.1 5.5 2.4 3.7 1 5.8 2.7 3.9 1.2 6 2.7 5.1 1.6 5.4 3 4.5 1.5 6 3.4 4.5 1.6 6.7 3.1 4.7 1.5 6.3 2.3 4.4 1.3 5.6 3 4.1 1.3 5.5 2.5 4 1.3 5.5 2.6 4.4 1.2 6.1 3 4.6 1.4 5.8 2.6 4 1.2 5 2.3 3.3 1 5.6 2.7 4.2 1.3 5.7 3 4.2 1.2 5.7 2.9 4.2 1.3 6.2 2.9 4.3 1.3 5.1 2.5 3 1.1 5.7 2.8 4.1 1.3 6.3 3.3 6 2.5 5.8 2.7 5.1 1.9 7.1 3 5.9 2.1 6.3 2.9 5.6 1.8 6.5 3 5.8 2.2 7.6 3 6.6 2.1 4.9 2.5 4.5 1.7 7.3 2.9 6.3 1.8 6.7 2.5 5.8 1.8 7.2 3.6 6.1 2.5 6.5 3.2 5.1 2 6.4 2.7 5.3 1.9 6.8 3 5.5 2.1 5.7 2.5 5 2 5.8 2.8 5.1 2.4 6.4 3.2 5.3 2.3 6.5 3 5.5 1.8 7.7 3.8 6.7 2.2 7.7 2.6 6.9 2.3 6 2.2 5 1.5 6.9 3.2 5.7 2.3 5.6 2.8 4.9 2 7.7 2.8 6.7 2 6.3 2.7 4.9 1.8 6.7 3.3 5.7 2.1 7.2 3.2 6 1.8 6.2 2.8 4.8 1.8 6.1 3 4.9 1.8 6.4 2.8 5.6 2.1 7.2 3 5.8 1.6 7.4 2.8 6.1 1.9 7.9 3.8 6.4 2 6.4 2.8 5.6 2.2 6.3 2.8 5.1 1.5 6.1 2.6 5.6 1.4 7.7 3 6.1 2.3 6.3 3.4 5.6 2.4 6.4 3.1 5.5 1.8 6 3 4.8 1.8 6.9 3.1 5.4 2.1 6.7 3.1 5.6 2.4 6.9 3.1 5.1 2.3 5.8 2.7 5.1 1.9 6.8 3.2 5.9 2.3 6.7 3.3 5.7 2.5 6.7 3 5.2 2.3 6.3 2.5 5 1.9 6.5 3 5.2 2 6.2 3.4 5.4 2.3 5.9 3 5.1 1.8
iris_id.txt
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2.matlab程序
My_GMM.m
function label_2=My_GMM(K)
%輸入K:聚類數,K個單高斯模型
%輸出label_2:聚的類,para_pi:單高斯權重,para_miu_new:高斯分布參數μ,para_sigma:高斯分布參數sigma
format long
eps=1e-15; %定義迭代終止條件的eps
data=dlmread('E:\www.cnblogs.comkailugaji\data\iris\iris_data.txt');
%----------------------------------------------------------------------------------------------------
%對data做最大-最小歸一化處理
[data_num,~]=size(data);
X=(data-ones(data_num,1)*min(data))./(ones(data_num,1)*(max(data)-min(data)));
[X_num,X_dim]=size(X);
para_sigma=zeros(X_dim,X_dim,K);
%----------------------------------------------------------------------------------------------------
%隨機初始化K個聚類中心
rand_array=randperm(X_num); %產生1~X_num之間整數的隨機排列
center=X(rand_array(1:K),:); %隨機排列取前K個數,在X矩陣中取這K行作為初始聚類中心
%根據上述聚類中心初始化參數
para_miu_new=center; %初始化參數miu
para_pi=ones(1,K)./K; %K類單高斯模型的權重
for k=1:K
para_sigma(:,:,k)=eye(X_dim); %K類單高斯模型的協方差矩陣,初始化為單位陣
end
%歐氏距離,計算(X-para_miu)^2=X^2+para_miu^2-2*X*para_miu',矩陣大小為X_num*K
distant=repmat(sum(X.*X,2),1,K)+repmat(sum(para_miu_new.*para_miu_new,2)',X_num,1)-2*X*para_miu_new';
%返回distant每行最小值所在的下標
[~,label_1]=min(distant,[],2);
for k=1:K
X_k=X(label_1==k,:); %X_k是一個(X_num/K, X_dim)的矩陣,把X矩陣分為K類
para_pi(k)=size(X_k,1)/X_num; %將(每一類數據的個數/X_num)作為para_pi的初始值
para_sigma(:,:,k)=cov(X_k); %para_sigma是一個(X_dim, X_dim)的矩陣,cov(矩陣)求的是每一列之間的協方差
end
%----------------------------------------------------------------------------------------------------
%EM算法
N_pdf=zeros(X_num,K);
while true
para_miu=para_miu_new;
%----------------------------------------------------------------------------------------------------
%E步
%單高斯分布的概率密度函數N_pdf
for k=1:K
X_miu=X-repmat(para_miu(k,:),X_num,1); %X-miu,(X_num, X_dim)的矩陣
sigma_inv=inv(para_sigma(:,:,k)); %sigma的逆矩陣,(X_dim, X_dim)的矩陣//很可能出現奇異矩陣
exp_up=sum((X_miu*sigma_inv).*X_miu,2); %指數的冪,(X-miu)'*sigma^(-1)*(X-miu)
coefficient=(2*pi)^(-X_dim/2)*sqrt(det(sigma_inv)); %高斯分布的概率密度函數e左邊的系數
N_pdf(:,k)=coefficient*exp(-0.5*exp_up);
end
% N_pdf=guass_pdf(X,K,para_miu,para_sigma);
responsivity=N_pdf.*repmat(para_pi,X_num,1); %響應度responsivity的分子,(X_num,K)的矩陣
responsivity=responsivity./repmat(sum(responsivity,2),1,K); %responsivity:在當前模型下第n個觀測數據來自第k個分模型的概率,即分模型k對觀測數據Xn的響應度
%----------------------------------------------------------------------------------------------------
%M步
R_k=sum(responsivity,1); %(1,K)的矩陣,把responsivity每一列求和
%更新參數miu
para_miu_new=diag(1./R_k)*responsivity'*X;
%更新k個參數sigma
for i=1:K
X_miu=X-repmat(para_miu_new(i,:),X_num,1);
para_sigma(:,:,i)=(X_miu'*(diag(responsivity(:,i))*X_miu))/R_k(i);
end
%更新參數pi
para_pi=R_k/sum(R_k);
%----------------------------------------------------------------------------------------------------
%迭代終止條件
if norm(para_miu_new-para_miu)<=eps
break;
end
end
%----------------------------------------------------------------------------------------------------
%聚類
[~,label_2]=max(responsivity,[],2);
succeed.m
function accuracy=succeed(K,id)
%輸入K:聚的類,id:訓練后的聚類結果,N*1的矩陣
N=size(id,1); %樣本個數
p=perms(1:K); %全排列矩陣
p_col=size(p,1); %全排列的行數
new_label=zeros(N,p_col); %聚類結果的所有可能取值,N*p_col
num=zeros(1,p_col); %與真實聚類結果一樣的個數
real_label=dlmread('E:\www.cnblogs.comkailugaji\data\iris\iris_id.txt');
%將訓練結果全排列為N*p_col的矩陣,每一列為一種可能性
for i=1:N
for j=1:p_col
for k=1:K
if id(i)==k
new_label(i,j)=p(j,k)-1;
end
end
end
end
%與真實結果比對,計算精確度
for j=1:p_col
for i=1:N
if new_label(i,j)==real_label(i)
num(j)=num(j)+1;
end
end
end
accuracy=max(num)/N;
3.結果
>> label_1=My_GMM(3); >> accuracy=succeed(3,label_1) accuracy = 0.966666666666667
4.注意
GMM算法我只進行了一次計算准確度,因為有可能會出現奇異矩陣的情況,導致算法出錯,現在我還沒有想出如何解決奇異矩陣的問題,因此只給出了一次循環。望指正。
2020.7.30 奇異問題已初步解決,見評論鏈接。
補充:GMM的Python代碼:upload/GMM.py at master · wl-lei/upload · GitHub
GMM的MATLAB代碼:https://github.com/kailugaji/Gaussian_Mixture_Model_for_Clustering
