kmeans理解


    最近看到Andrew Ng的一篇論文,文中用到了Kmeans和DL結合的思想,突然發現自己對ML最基本的聚類算法都不清楚,於是着重的看了下Kmeans,並在網上找了程序跑了下。

kmeans是unsupervised learning最基本的一個聚類算法,我們可以用它來學習無標簽的特征,其基本思想如下:

    首先給出原始數據{x1,x2,...,xn},這些數據沒有被標記的。

    初始化k個隨機數據u1,u2,...,uk,每一個ui都是一個聚類中心,k就是分為k類,這些xn和uk都是向量。

    根據下面兩個公式迭代就能求出最終所有的聚類中心u。

    formula 1:

                                                                               image

    其中xi是第i個data,uj是第j(1~k)的聚類中心,這個公式的意思就是求出每一個data到k個聚類中心的距離,並求出最小距離,那么數據xi就可以歸到這一類。

    formula 2:

                                                                               image

    這個公式的目的是求出新的聚類中心,由於之前已經求出來每一個data到每一類的聚類中心uj,那么可以在每一類總求出其新的聚類中心(用這一類每一個data到中心的距離之和除以總的data),分別對k類同樣的處理,這樣我們就得到了k個新的聚類中心。

    反復迭代公式一和公式二,知道聚類中心不怎么改變為止。

    我們利用3維數據進行kmeans,代碼如下:

    run_means.m

   1: %%用來kmeans聚類的一個小代碼
   2:  
   3: clear all;
   4: close all;
   5: clc;
   6:  
   7: %第一類數據
   8: mu1=[0 0 0];  %均值
   9: S1=[0.3 0 0;0 0.35 0;0 0 0.3];  %協方差
  10: data1=mvnrnd(mu1,S1,100);   %產生高斯分布數據
  11:  
  12: %%第二類數據
  13: mu2=[1.25 1.25 1.25];
  14: S2=[0.3 0 0;0 0.35 0;0 0 0.3];
  15: data2=mvnrnd(mu2,S2,100);
  16:  
  17: %第三個類數據
  18: mu3=[-1.25 1.25 -1.25];
  19: S3=[0.3 0 0;0 0.35 0;0 0 0.3];
  20: data3=mvnrnd(mu3,S3,100);
  21:  
  22: %顯示數據
  23: plot3(data1(:,1),data1(:,2),data1(:,3),'+');
  24: hold on;
  25: plot3(data2(:,1),data2(:,2),data2(:,3),'r+');
  26: plot3(data3(:,1),data3(:,2),data3(:,3),'g+');
  27: grid on;
  28:  
  29: %三類數據合成一個不帶標號的數據類
  30: data=[data1;data2;data3];   %這里的data是不帶標號的
  31:  
  32: %k-means聚類
  33: [u re]=KMeans(data,3);  %最后產生帶標號的數據,標號在所有數據的最后,意思就是數據再加一維度
  34: [m n]=size(re);
  35:  
  36: %最后顯示聚類后的數據
  37: figure;
  38: hold on;
  39: for i=1:m 
  40:     if re(i,4)==1   
  41:          plot3(re(i,1),re(i,2),re(i,3),'ro'); 
  42:     elseif re(i,4)==2
  43:          plot3(re(i,1),re(i,2),re(i,3),'go'); 
  44:     else 
  45:          plot3(re(i,1),re(i,2),re(i,3),'bo'); 
  46:     end
  47: end
  48: grid on;

 

    KMeans.m

   1: %N是數據一共分多少類
   2: %data是輸入的不帶分類標號的數據
   3: %u是每一類的中心
   4: %re是返回的帶分類標號的數據
   5: function [u re]=KMeans(data,N)   
   6:     [m n]=size(data);   %m是數據個數,n是數據維數
   7:     ma=zeros(n);        %每一維最大的數
   8:     mi=zeros(n);        %每一維最小的數
   9:     u=zeros(N,n);       %隨機初始化,最終迭代到每一類的中心位置
  10:     for i=1:n
  11:        ma(i)=max(data(:,i));    %每一維最大的數
  12:        mi(i)=min(data(:,i));    %每一維最小的數
  13:        for j=1:N
  14:             u(j,i)=ma(i)+(mi(i)-ma(i))*rand();  %隨機初始化,不過還是在每一維[min max]中初始化好些
  15:        end      
  16:     end
  17:    
  18:     while 1
  19:         pre_u=u;            %上一次求得的中心位置
  20:         for i=1:N
  21:             tmp{i}=[];      % 公式一中的x(i)-uj,為公式一實現做准備
  22:             for j=1:m
  23:                 tmp{i}=[tmp{i};data(j,:)-u(i,:)];
  24:             end
  25:         end
  26:         
  27:         quan=zeros(m,N);
  28:         for i=1:m        %公式一的實現
  29:             c=[];        %c 是到每類的距離
  30:             for j=1:N
  31:                 c=[c norm(tmp{j}(i,:))];
  32:             end
  33:             [junk index]=min(c);
  34:             quan(i,index)=norm(tmp{index}(i,:));           
  35:         end
  36:         
  37:         for i=1:N            %公式二的實現
  38:            for j=1:n
  39:                 u(i,j)=sum(quan(:,i).*data(:,j))/sum(quan(:,i));
  40:            end           
  41:         end
  42:         
  43:         if norm(pre_u-u)<0.1  %不斷迭代直到位置不再變化
  44:             break;
  45:         end
  46:     end
  47:     
  48:     re=[];
  49:     for i=1:m
  50:         tmp=[];
  51:         for j=1:N
  52:             tmp=[tmp norm(data(i,:)-u(j,:))];
  53:         end
  54:         [junk index]=min(tmp);
  55:         re=[re;data(i,:) index];
  56:     end
  57:     
  58: end

    原始數據如下所示,分為三類:

                                                               image

    當k取2時,聚成2類:

                                                               image

    當k取3時,聚成3類:

                                                               image


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM