聚類算法,不是分類算法。
分類算法是給一個數據,然后判斷這個數據屬於已分好的類中的具體哪一類。
聚類算法是給一大堆原始數據,然后通過算法將其中具有相似特征的數據聚為一類。
這里的k-means聚類,是事先給出原始數據所含的類數,然后將含有相似特征的數據聚為一個類中。
所有資料中還是Andrew Ng介紹的明白。
首先給出原始數據{x1,x2,...,xn},這些數據沒有被標記的。
初始化k個隨機數據u1,u2,...,uk。這些xn和uk都是向量。
根據下面兩個公式迭代就能求出最終所有的u,這些u就是最終所有類的中心位置。
公式一:
意思就是求出所有數據和初始化的隨機數據的距離,然后找出距離每個初始數據最近的數據。
公式二:
意思就是求出所有和這個初始數據最近原始數據的距離的均值。
然后不斷迭代兩個公式,直到所有的u都不怎么變化了,就算完成了。
先看看一些結果:
用三個二維高斯分布數據畫出的圖:
通過對沒有標記的原始數據進行kmeans聚類得到的分類,十字是最終迭代位置:
下面是Matlab代碼,這里我把測試數據改為了三維了,函數是可以處理各種維度的。
main.m
clear all; close all; clc; %第一類數據 mu1=[0 0 0]; %均值 S1=[0.3 0 0;0 0.35 0;0 0 0.3]; %協方差 data1=mvnrnd(mu1,S1,100); %產生高斯分布數據 %%第二類數據 mu2=[1.25 1.25 1.25]; S2=[0.3 0 0;0 0.35 0;0 0 0.3]; data2=mvnrnd(mu2,S2,100); %第三個類數據 mu3=[-1.25 1.25 -1.25]; S3=[0.3 0 0;0 0.35 0;0 0 0.3]; data3=mvnrnd(mu3,S3,100); %顯示數據 plot3(data1(:,1),data1(:,2),data1(:,3),'+'); hold on; plot3(data2(:,1),data2(:,2),data2(:,3),'r+'); plot3(data3(:,1),data3(:,2),data3(:,3),'g+'); grid on; %三類數據合成一個不帶標號的數據類 data=[data1;data2;data3]; %這里的data是不帶標號的 %k-means聚類 [u re]=KMeans(data,3); %最后產生帶標號的數據,標號在所有數據的最后,意思就是數據再加一維度 [m n]=size(re); %最后顯示聚類后的數據 figure; hold on; for i=1:m if re(i,4)==1 plot3(re(i,1),re(i,2),re(i,3),'ro'); elseif re(i,4)==2 plot3(re(i,1),re(i,2),re(i,3),'go'); else plot3(re(i,1),re(i,2),re(i,3),'bo'); end end grid on;
KMeans.m
%N是數據一共分多少類 %data是輸入的不帶分類標號的數據 %u是每一類的中心 %re是返回的帶分類標號的數據 function [u re]=KMeans(data,N) [m n]=size(data); %m是數據個數,n是數據維數 ma=zeros(n); %每一維最大的數 mi=zeros(n); %每一維最小的數 u=zeros(N,n); %隨機初始化,最終迭代到每一類的中心位置 for i=1:n ma(i)=max(data(:,i)); %每一維最大的數 mi(i)=min(data(:,i)); %每一維最小的數 for j=1:N u(j,i)=ma(i)+(mi(i)-ma(i))*rand(); %隨機初始化,不過還是在每一維[min max]中初始化好些 end end while 1 pre_u=u; %上一次求得的中心位置 for i=1:N tmp{i}=[]; % 公式一中的x(i)-uj,為公式一實現做准備 for j=1:m tmp{i}=[tmp{i};data(j,:)-u(i,:)]; end end quan=zeros(m,N); for i=1:m %公式一的實現 c=[]; for j=1:N c=[c norm(tmp{j}(i,:))]; end [junk index]=min(c); quan(i,index)=norm(tmp{index}(i,:)); end for i=1:N %公式二的實現 for j=1:n u(i,j)=sum(quan(:,i).*data(:,j))/sum(quan(:,i)); end end if norm(pre_u-u)<0.1 %不斷迭代直到位置不再變化 break; end end re=[]; for i=1:m tmp=[]; for j=1:N tmp=[tmp norm(data(i,:)-u(j,:))]; end [junk index]=min(tmp); re=[re;data(i,:) index]; end end