聚類算法
李鑫 2014210820 電子系
1、kmeans算法
1.1Kmeans算法理論基礎
K均值算法能夠使聚類域中所有樣品到聚類中心距離平方和最小。其原理為:先取k個初始聚類中心,計算每個樣品到這k個中心的距離,找出最小距離,把樣品歸入最近的聚類中心,修改中心點的值為本類所有樣品的均值,再計算各個樣品到新的聚類中心的距離,重新歸類,修改新的中心點,直到新的聚類中心和上一次聚類中心差距很小時結束。此算法結果受到聚類中心的個數和聚類中心初次選擇影響,也受到樣品的幾個性質及排列次序的影響。如果樣品的幾何性質表明它們能形成幾塊孤立的區域,則算法一般可以收斂。
1.2Kmeans算法實現步驟
①產生二維高斯數據,設置聚類中心數N
②隨機取N個點作為聚類中心。
③計算其余樣品到這N個聚類中心的距離,將他們歸到最近的類,到所有的樣品都歸完類。
④計算各個類樣品的平均值作為該類新的聚類中心,再計算所有樣品到新的聚類中心的距離,把他們歸到最近的類,如此反復,直到聚類中心不再變化為止。
1.3Kmeans算法編程實現
clear all;close all;clc;
% 第一組數據
mu1=[0 0 ]; %均值
S1=[.1 0 ;0 .1]; %協方差
data1=mvnrnd(mu1,S1,100); %產生高斯分布數據
%第二組數據
mu2=[1.25 1.25 ];
S2=[.1 0 ;0 .1];
data2=mvnrnd(mu2,S2,100);
% 第三組數據
mu3=[-1.25 1.25 ];
S3=[.1 0 ;0 .1];
data3=mvnrnd(mu3,S3,100);
% 顯示數據
plot(data1(:,1),data1(:,2),'b+');
hold on;
plot(data2(:,1),data2(:,2),'r+');
plot(data3(:,1),data3(:,2),'g+');
grid on;
% 三類數據合成一個不帶標號的數據類
data=[data1;data2;data3];
N=3;%設置聚類數目
[m,n]=size(data);
pattern=zeros(m,n+1);
center=zeros(N,n);%初始化聚類中心
pattern(:,1:n)=data(:,:);
for x=1:N
center(x,:)=data( randi(300,1),:);%第一次隨機產生聚類中心
end
while 1
distence=zeros(1,N);
num=zeros(1,N);
new_center=zeros(N,n);
for x=1:m
for y=1:N
distence(y)=norm(data(x,:)-center(y,:));%計算到每個類的距離
end
[~, temp]=min(distence);%求最小的距離
pattern(x,n+1)=temp;
end
k=0;
for y=1:N
for x=1:m
if pattern(x,n+1)==y
new_center(y,:)=new_center(y,:)+pattern(x,1:n);
num(y)=num(y)+1;
end
end
new_center(y,:)=new_center(y,:)/num(y);
if norm(new_center(y,:)-center(y,:))<0.1
k=k+1;
end
end
if k==N
break;
else
center=new_center;
end
end
[m, n]=size(pattern);
%最后顯示聚類后的數據
figure;
hold on;
for i=1:m
if pattern(i,n)==1
plot(pattern(i,1),pattern(i,2),'r*');
plot(center(1,1),center(1,2),'ko');
elseif pattern(i,n)==2
plot(pattern(i,1),pattern(i,2),'g*');
plot(center(2,1),center(2,2),'ko');
elseif pattern(i,n)==3
plot(pattern(i,1),pattern(i,2),'b*');
plot(center(3,1),center(3,2),'ko');
elseif pattern(i,n)==4
plot(pattern(i,1),pattern(i,2),'y*');
plot(center(4,1),center(4,2),'ko');
else
plot(pattern(i,1),pattern(i,2),'m*');
plot(center(4,1),center(4,2),'ko');
end
end
grid on;
1.3Kmeans算法測試結果:
a)高斯數對 b)N=2
c) N=3 d)N=4
e)N=5 f)N=6
可以看到聚類數目N對聚類有一定影響,同時在N相同的情況下每次的聚類結果也不完全一樣,說明初始的聚類中心對聚類結果也有一定影響。
2層次聚類算法
層次聚類算法分為合並算法和分裂算法。合並算法會在每一步減少聚類中心的數量,聚類產生的結果來自前一步的兩個聚類的合並;分裂算法與合並算法原理相反,在每一步增加聚類的數量,每一步聚類產生的結果都將是前一步聚類中心分裂得到的。合並算法現將每個樣品自成一類,然后根據類間距離的不同,合並距離小於閾值的類。我用了基於最短距離算法的層次聚類算法,最短距離算法認為,只要兩個類的最小距離小於閾值,就將兩個類合並成一個類。
2.1層次聚類算法實現步驟
①獲得所有樣品特征
②設置閾值
③將所有樣品各分一類,聚類中心等於樣品總個數。
④對所有樣品循環:
找到距離最近的兩類pi,pj,設置距離minDis
若minDis<=T,則合並pi和pj否則退出循環。
2.2層次聚類算法的編程實現
clear all;close all;clc;
% 第一類數據
mu1=[0 0 ]; %均值
S1=[0.1 0 ;0 0.1]; %協方差
data1=mvnrnd(mu1,S1,100); %產生高斯分布數據
%第二類數據
mu2=[1.25 1.25 ];
S2=[0.1 0 ;0 0.1];
data2=mvnrnd(mu2,S2,100);
% 第三個類數據
mu3=[-1.25 1.25 ];
S3=[0.1 0 ;0 0.1];
data3=mvnrnd(mu3,S3,100);
% 顯示數據
plot(data1(:,1),data1(:,2),'b+');
hold on;
plot(data2(:,1),data2(:,2),'r+');
plot(data3(:,1),data3(:,2),'g+');
grid on;
% 三類數據合成一個不帶標號的數據類
data=[data1;data2;data3];
[m,n]=size(data);
patternNum=m;
T=0.1;
pattern=zeros(m,n+1);
for i=1:patternNum
pattern(i,n+1)=i;
pattern(i,1:n)=data(i,:);
end
while 1
minDis=inf;
pi=0;
pj=0;
% 尋找距離最近的兩個類計算最小距離
for i=1:patternNum-1
for j=i+1:patternNum
if(pattern(i,n+1)~=pattern(j,n+1))
tempDis=norm(pattern(i,1:n)-pattern(j,1:n));
if(tempDis<minDis)
minDis=tempDis;
pi=pattern(i,n+1);
pj=pattern(j,n+1);
end
end
end
end
% 距離小於閾值則合並兩個類
if(minDis<=T)
if(pi>pj)
temp=pi;
pi=pj;
pj=temp;
end
for i=1:patternNum
if(pattern(i,n+1)==pi)
pattern(i,n+1)=pi;
elseif(pattern(i,n+1)>pi)
pattern(i,n+1)=pattern(i,n+1)-1;
end
end
else
break;
end
end
disp('ok')
[m, n]=size(pattern);
%最后顯示聚類后的數據
figure;
hold on;
for i=1:m
if pattern(i,n)==1
plot(pattern(i,1),pattern(i,2),'r*');
elseif pattern(i,n)==2
plot(pattern(i,1),pattern(i,2),'g*');
elseif pattern(i,n)==3
plot(pattern(i,1),pattern(i,2),'b*');
elseif pattern(i,n)==4
plot(pattern(i,1),pattern(i,2),'y*');
else
plot(pattern(i,1),pattern(i,2),'m*');
end
end
grid on;
2.3層次聚類算法測試結果:
下圖是產生的高斯數對及當閾值設置為T=0.1的時候的結果:
當T=0.5時的結果:
可見當閾值設的比較大時所有的都將成為一類。所以閾值的設置很重要。
基於最小距離的層次聚算法對類間距要求很高,例如將高斯數對的協方差加大,產生距離比較近的數對時,層次聚類算法就會出現很大問題如下圖:
但是在相同的生成參數下,kmeans卻有很好的效果:
3基於kmeans的圖像分割
Kmeans之前已經講過了,其圖像分割只不過是把之前的高斯數對換成圖像二維像素點,彩色圖像每個像素點有rgb三個分量,灰度圖像只有一個分量。
3.1編程實現
clear;clc;close all;
data=imread('src1.bmp');
imshow(data)
[m,n,c]=size(data);
[mu,pattern]=k_mean_Seg(data,2);
for x=1:m
for y=1:n
if pattern(x,y,1)==1
data(x,y,1)=0;
data(x,y,2)=0;
data(x,y,3)=255;
elseif pattern(x,y,1)==2
data(x,y,1)=0;
data(x,y,2)=255;
data(x,y,3)=0;
elseif pattern(x,y,1)==3
data(x,y,1)=255;
data(x,y,2)=0;
data(x,y,3)=0;
else
data(x,y,1)=255;
data(x,y,2)=255;
data(x,y,3)=0;
end
end
end
figure;
imshow(data);
function [num,mask]=k_mean_Seg(src,k)
src=double(src);
img=src;
src=src(:);
mi=min(src);
src=src-mi+1;
L=length(src);
m=max(src)+1;
hist=zeros(1,m);
histc=zeros(1,m);
for i=1:L
if(src(i)>0)
hist(src(i))=hist(src(i))+1;
end;
end
ind=find(hist);
hl=length(ind);
num=(1:k)*m/(k+1);
while(true)
prenum=num;
for i=1:hl
c=abs(ind(i)-num);
cc=find(c==min(c));
histc(ind(i))=cc(1);
end
for i=1:k,
a=find(histc==i);
num(i)=sum(a.*hist(a))/sum(hist(a));
end
if(num==prenum)
break;
end;
end
L=size(img);
mask=zeros(L);
for i=1:L(1),
for j=1:L(2),
c=abs(img(i,j)-num);
a=find(c==min(c));
mask(i,j)=a(1);
end
end
num=num+mi-1;
3.2結果展示
a)原圖 b)N=2
c)N=3 d)N=4
對於灰度圖像,用data=rgb2gray(data); 轉化成為成為灰度圖像后面的顯示換成如下:
for x=1:m
for y=1:n
if pattern(x,y)==1
data(x,y)=0;
elseif pattern(x,y)==2
data(x,y,1)=80;
elseif pattern(x,y)==3
data(x,y)=180;
else
data(x,y)=255;
end
end
end
a)原圖 b)N=2
c)N=3 d)N=4
4、基於層次聚類算法的圖像分割
層次聚類算法前面已經提到,其計算量遠遠大於kmeans,又由於在MATLAB平台下對多重嵌套的for循環處理速度很慢,並沒有得到運行結果。僅將代碼附如下:
4.1層次聚類算法的圖像分割編程實現
clear;clc;close all;
data=imread('src1.bmp');
data=rgb2gray(data);
imshow(data);
data=double(data);
[m,n]=size(data);
patternNum=m*n;
T=10;
pattern=zeros(m,n);
for x=1:m
for y=1:n
for i=1:patternNum
pattern(x,y)=i;
end
end
end
while 1
minDis=inf;
pi=0;
pj=0;
% 尋找距離最近的兩個類計算最小距離
for x=1:m
for y=1:n
for i=1:m
for j=1:n
if(pattern(x,y)~=pattern(i,j))
tempDis=abs(pattern(x,y)-pattern(i,j));
if(tempDis<minDis)
minDis=tempDis;
pi=pattern(x,y);
pj=pattern(i,j);
end
end
end
end
end
end
% 距離小於閾值則合並兩個類
if(minDis<=T)
if(pi>pj)
temp=pi;
pi=pj;
pj=temp;
end
for i=1:m
for j=1:n
if(pattern(i,j)==pi)
pattern(i,j)=pi;
elseif(pattern(i,j)>pi)
pattern(i,j)=pattern(i,j)-1;
end
end
end
else
break;
end
end
disp('ok')
[m, n]=size(pattern);
%最后顯示聚類后的數據
for x=1:m
for y=1:n
if pattern(x,y)==1
data(x,y,1)=0;
data(x,y,2)=0;
data(x,y,3)=255;
elseif pattern(x,y)==2
data(x,y,1)=0;
data(x,y,2)=255;
data(x,y,3)=0;
elseif pattern(x,y)==3
data(x,y,1)=255;
data(x,y,2)=0;
data(x,y,3)=0;
else
data(x,y,1)=255;
data(x,y,2)=255;
data(x,y,3)=0;
end
end
end
figure;
imshow(data);