0.目錄
1.前置知識
本文內容基於《Accelerating exact k-means algorithms with geometric reasoning》
KDTree
k-means
2.思路介紹
k-means算法在初始化中心點后C通過以下迭代步驟得到局部最優解:
a.將數據集D中的點x賦給距離最近的中心點
b.在每個聚類中,重新計算中心點
傳統算法中,a步需要計算n*k個距離(n為D的大小,k為聚類個數),b步需要相加n個數據點
而在KDTree中,每個非葉子節點,都存儲了其包含的數據的數據范圍信息h。
![]() |
二維空間中的h可以使用矩形來表示 圖中*為點,紅色矩形為數據范圍h |
a. 如果通過范圍信息,能判斷節點中數據都屬於中心點c,則能省去節點中數據到中心點距離的計算
如果能判斷h中數據都不屬於某中心點c,則能省去節點中數據到中心點c距離的計算
b. 當知道節點中數據全部屬於c,能將h中事先加好的統計量直接加到c的統計量中
3.詳述
3.1 確定h的中心點(h中所有數據都離這個中心點近而離其他中心點遠)
KDTree的節點中存儲的Max(各維度上的最大值)和Min(各維度上的最小值)確定了節點中數據的范圍
中心點有(c1,c2,...,ck)
a. 判斷是否可能存在
計算各中心點到h的最小距離(參考KDTree最近鄰查找,第5步) d(ci,h)
如果存在一個最小距離,則這個ci可能是h的中心點(還需要進一步判斷)
若存在不止一個最小距離,則h的中心點不存在,需要將h分割為更小(在h的左右樹上)后查找
![]() |
正方形表示的點都在h的內部 所以他們到h的最小距離相同,都為0 此h不存在中心點 |
b. 進一步判斷,ci是否為中心點
![]() |
L12為c1和c2連線的中位線,h全部落在c1一邊, 所以h中的全部點離c1比離c2近,稱c1優於c2 而對於c1和c3來說,h有一部分落在c1,有一部分落在c3 c1不優於c3 |
判斷c1是否優於c3: 取向量v=(c3-c1),找到點p屬於h,使<v,p>內積最大 v各維度正負情況(+,-),則p在x軸上盡可能大,y軸上盡可能小,取到p13 p13離c3近,所以c1不優於c3 |
如果ci在優於其他點,則可以判定ci即為h的中心點;否則ci不是h的中心點;
雖然ci不是h的中心點,但是得到的信息,如ci優於c2,能將c2從h的子樹的中心點候選列表中排除
3.2 算法步驟
KDTree中每個非葉子節點特殊屬性: sumOfPoints:m維向量(m是數據的維度),其i維度的值為節點中數據第i維的和 n:節點中數據的個數 |
輸入:KDTree,C 包括中心點(c1,c2,...,ck) |
輸出:CNEW 新的k個中心點 |
node=KDTree.root centers=k*m的數組//每行存儲屬於這個中心點的數據的和 datacount=k*1的數組//存儲屬於這個中心點的數據個數 |
UPDATE(node,C): IF node為葉子節點 遍歷計算得到離node最近的節點ct centers[t]+=node.value; datacount[t]+=1; RETURN; FOR(ci in C) 計算d(ci,node.h) IF 有多個最小的d(ci,node.h) UPDATE(node.left,C); UPDATE(node.right,C); RETURN; //假設d(ci,node.h)最小的是ct CTOVER=[]//存儲劣於ct的 FOR(ci in C(除了ct)) IF(ct 優於 ci) CTOVER.ADD(ci) IF(LEN(CTOVER)=LEN(C)-1)//ct優於其他的中心點 centers[t]+=node.sumOfPoints; datacount[t]+=node.n; RETURN; CT=(ci in C 且 ci not in CTOVER)//排除比ct差的中心點 UPDATE(node.left,CT); UPDATE(node.right,CT); RETURN; |
4.java實現
a.用下列matlab方法生成測試數據
#centers為中心點個數,dimention為數據維度,persize為每個中心點包含的數據量
function cdata(centers,dimention,persize) d=zeros(centers*persize,dimention); sigma=eye(dimention); for i=1:centers mu=randi(20,1,dimention); d(((i-1)*persize+1):i*persize,:)=mvnrnd(mu,sigma,persize); end dlmwrite('d.txt',d,'delimiter','\t','precision','%10.4f') end
b.kdtree
package cc; import java.util.ArrayList; import java.util.HashMap; public class MRKDTree { private Node mrkdtree; private class Node{ //分割的維度 int partitionDimention; //分割的值 double partitionValue; //如果為非葉子節點,該屬性為空 //否則為數據 double[] value; //是否為葉子 boolean isLeaf=false; //左樹 Node left; //右樹 Node right; //每個維度的最小值 double[] min; //每個維度的最大值 double[] max; double[] sumOfPoints; int n; } private static class UtilZ{ /** * 計算給定維度的方差 * @param data 數據 * @param dimention 維度 * @return 方差 */ static double variance(ArrayList<double[]> data,int dimention){ double vsum = 0; double sum = 0; for(double[] d:data){ sum+=d[dimention]; vsum+=d[dimention]*d[dimention]; } int n = data.size(); return vsum/n-Math.pow(sum/n, 2); } /** * 取排序后的中間位置數值 * @param data 數據 * @param dimention 維度 * @return */ static double median(ArrayList<double[]> data,int dimention){ double[] d =new double[data.size()]; int i=0; for(double[] k:data){ d[i++]=k[dimention]; } return median(d); } private static double median(double[] a){ int n=a.length; int L = 0; int R = n - 1; int k = n / 2; int i; int j; while (L < R) { double x = a[k]; i = L; j = R; do { while (a[i] < x) i++; while (x < a[j]) j--; if (i <= j) { double t = a[i]; a[i] = a[j]; a[j] = t; i++; j--; } } while (i <= j); if (j < k) L = i; if (k < i) R = j; } return a[k]; } static double[][] maxmin(ArrayList<double[]> data,int dimentions){ double[][] mm = new double[2][dimentions]; //初始化 第一行為min,第二行為max for(int i=0;i<dimentions;i++){ mm[0][i]=mm[1][i]=data.get(0)[i]; for(int j=1;j<data.size();j++){ double[] d = data.get(j); if(d[i]<mm[0][i]){ mm[0][i]=d[i]; }else if(d[i]>mm[1][i]){ mm[1][i]=d[i]; } } } return mm; } static double distance(double[] a,double[] b){ double sum = 0; for(int i=0;i<a.length;i++){ sum+=Math.pow(a[i]-b[i], 2); } return sum; } /** * 在max和min表示的超矩形中的點和點a的最小距離 * @param a 點a * @param max 超矩形各個維度的最大值 * @param min 超矩形各個維度的最小值 * @return 超矩形中的點和點a的最小距離 */ static double mindistance(double[] a,double[] max,double[] min){ double sum = 0; for(int i=0;i<a.length;i++){ if(a[i]>max[i]) sum += Math.pow(a[i]-max[i], 2); else if (a[i]<min[i]) { sum += Math.pow(min[i]-a[i], 2); } } return sum; } public static double[] sumOfPoints(ArrayList<double[]> data, int dimentions) { double[] res = new double[dimentions]; for(double[] d:data){ for(int i=0;i<dimentions;i++){ res[i]+=d[i]; } } return res; } /** * 判斷centerd是否在h上優於c * @param centerd * @param c * @param max * @param min * @return */ public static boolean isOver(double[] center, double[] c, double[] max, double[] min) { double discenter = 0; double disc = 0; for(int i=0;i<c.length;i++){ if(c[i]-center[i]>0){ disc+=Math.pow(max[i]-c[i],2); discenter+=Math.pow(max[i]-center[i],2); }else if(c[i]-center[i]<0) { disc+=Math.pow(min[i]-c[i],2); discenter+=Math.pow(min[i]-center[i],2); } } return discenter<disc; } } private MRKDTree() {} /** * 構建樹 * @param input 輸入 * @return KDTree樹 */ public static MRKDTree build(double[][] input){ int n = input.length; int m = input[0].length; ArrayList<double[]> data =new ArrayList<double[]>(n); for(int i=0;i<n;i++){ double[] d = new double[m]; for(int j=0;j<m;j++) d[j]=input[i][j]; data.add(d); } MRKDTree tree = new MRKDTree(); tree.mrkdtree = tree.new Node(); tree.buildDetail(tree.mrkdtree, data, m,0); return tree; } /** * 循環構建樹 * @param node 節點 * @param data 數據 * @param dimentions 數據的維度 */ private void buildDetail(Node node,ArrayList<double[]> data,int dimentions,int lv){ if(data.size()==1){ node.isLeaf=true; node.value=data.get(0); return; } //選擇方差最大的維度 /* node.partitionDimention=-1; double var = -1; double tmpvar; for(int i=0;i<dimentions;i++){ tmpvar=UtilZ.variance(data, i); if (tmpvar>var){ var = tmpvar; node.partitionDimention = i; } } //如果方差=0,表示所有數據都相同,判定為葉子節點 if(var<1e-10){ node.isLeaf=true; node.value=data.get(0); return; } */ double[][] maxmin=UtilZ.maxmin(data, dimentions); node.min = maxmin[0]; node.max = maxmin[1]; //選取方差大的維度,會需要很長時間 //改成使用選取數據范圍最大的維度 //這樣構建kdtree的速度會變快,但是在kmean更新中心點會變慢 boolean isleaf = true; for(int i=0;i<node.min.length;i++) if(node.min[i]!=node.max[i]){ isleaf=false; break; } if(isleaf){ node.isLeaf=true; node.value=data.get(0); return; } node.partitionDimention=-1; double diff = -1; double tmpdiff; for(int i=0;i<dimentions;i++){ tmpdiff=node.max[i]-node.min[i]; if (tmpdiff>diff){ diff = tmpdiff; node.partitionDimention = i; } } node.sumOfPoints = UtilZ.sumOfPoints(data,dimentions); node.n = data.size(); //選擇分割的值 node.partitionValue=UtilZ.median(data, node.partitionDimention); if(node.partitionValue==node.min[node.partitionDimention]){ node.partitionValue+=1e-5; } int size = (int)(data.size()*0.55); ArrayList<double[]> left = new ArrayList<double[]>(size); ArrayList<double[]> right = new ArrayList<double[]>(size); for(double[] d:data){ if (d[node.partitionDimention]<node.partitionValue) { left.add(d); }else { right.add(d); } } Node leftnode = new Node(); Node rightnode = new Node(); node.left=leftnode; node.right=rightnode; buildDetail(leftnode, left, dimentions,lv+1); buildDetail(rightnode, right, dimentions,lv+1); } public double[][] updateCentroids(double[][] cs){ int k = cs.length; int m = cs[0].length; double[][] entroids = new double[k][m]; int[] datacount = new int[k]; HashMap<Integer, double[]> cscopy = new HashMap<Integer, double[]>(); for(int i=0;i<k;i++) cscopy.put(i, cs[i]); updateCentroidsDetail(mrkdtree,cscopy,entroids,datacount,k,m); double[][] csnew = new double[k][m]; for(int i=0;i<k;i++){ for(int j=0;j<m;j++){ csnew[i][j]=entroids[i][j]/datacount[i]; } } return csnew; } private void updateCentroidsDetail(Node node, HashMap<Integer, double[]> cs, double[][] entroids, int[] datacount,int k,int m) { //如果是葉子節點 if(node.isLeaf){ double[] v=node.value; double dis=Double.MAX_VALUE; double tdis; int index = -1; //找到所屬的中心點 for(Integer i: cs.keySet()){ double[] c = cs.get(i); tdis = UtilZ.distance(c, v); if(tdis<dis){ dis=tdis; index=i; } } //更新統計信息 datacount[index]++; for(int i=0;i<m;i++){ entroids[index][i]+=v[i]; } return; } double[] stack = new double[k]; int stackpoint = 0; int center=0; double tdis; for(Integer i: cs.keySet()){ double[] c = cs.get(i); tdis = UtilZ.mindistance(c, node.max, node.min); if(stackpoint==0){ stack[stackpoint++]=tdis; center=i; }else if (tdis<stack[stackpoint-1]) { stackpoint=1; stack[0]=tdis; center=i; }else if (tdis==stack[stackpoint-1]) { stack[stackpoint++]=tdis; } } //stackpoint>1,說明有多個最小值,不存在中心點 if(stackpoint!=1){ updateCentroidsDetail(node.left, cs, entroids, datacount, k, m); updateCentroidsDetail(node.right, cs, entroids, datacount, k, m); return; } HashMap<Integer, Boolean> ctover = new HashMap<Integer, Boolean>(); double[] centerd = cs.get(center); for(Integer i: cs.keySet()){ if(i==center) continue; double[] c = cs.get(i); if(UtilZ.isOver(centerd,c,node.max,node.min)){ ctover.put(i, true); } } if(ctover.size()==cs.size()-1){ //此時中心點即為center,更新信息 datacount[center]+=node.n; for(int i=0;i<m;i++){ entroids[center][i]+=node.sumOfPoints[i]; } return; } //將其比center差的中心點排除 HashMap<Integer, double[]> csnew = new HashMap<Integer, double[]>(); for(Integer i:cs.keySet()){ if(!ctover.containsKey(i)) csnew.put(i, cs.get(i)); } updateCentroidsDetail(node.left, csnew, entroids, datacount, k, m); updateCentroidsDetail(node.right, csnew, entroids, datacount, k, m); } }
c.kmeans
import cc.MRKDTree; public class KMeans { private double[][] centroids; private KMeans(){} public static class UtilZ{ static double[][] randomCentroids(double[][] data,int k){ double[][] res = new double[k][]; for(int i=0;i<k;i++){ res[i] = data[(int)(Math.random()*data.length)]; } return res; } static boolean converged(double[][] c1,double[][] c2,double c){ for(int i=0;i<c1.length;i++){ if(changed(c1[i],c2[i])>c){ return false; } } return true; } private static double changed(double[] c1,double[] c2){ double change=0; double total=0; for(int i=0;i<c1.length;i++){ total+=Math.pow(c1[i], 2); change+=Math.pow(c1[i]-c2[i], 2); } return Math.sqrt(change/total); } static double distance(double[] c1,double[] c2){ double sum = 0; for(int i=0;i<c1.length;i++){ sum+=Math.pow(c1[i]-c2[i], 2); } return sum; } } public static KMeans build(double[][] input,int k,double c,double[][] cs){ long start = System.currentTimeMillis(); MRKDTree tree = MRKDTree.build(input); System.out.println("treeConstruct:"+(System.currentTimeMillis()-start)); double[][] csnew = tree.updateCentroids(cs); while(!UtilZ.converged(cs, csnew, c)){ cs=csnew; csnew=tree.updateCentroids(cs); } KMeans km = new KMeans(); km.centroids=csnew; return km; } public static KMeans buildOri(double[][] input,int k,double c,double[][] cs){ double[][] csnew = updateOri(input,cs); while(!UtilZ.converged(cs, csnew, c)){ cs=csnew; csnew=updateOri(input,cs); } KMeans km = new KMeans(); km.centroids=csnew; return km; } private static double[][] updateOri(double[][] input,double[][] cs){ int[] center = new int[input.length]; for(int i=0;i<input.length;i++){ double dismin = Double.MAX_VALUE; for(int j=0;j<cs.length;j++){ double dis = UtilZ.distance(input[i], cs[j]); if(dis<dismin){ dismin=dis; center[i]=j; } } } double[][] nct =new double[cs.length][cs[0].length]; int[] datacount = new int[cs.length]; for(int i=0;i<input.length;i++){ double[] n = input[i]; int belong = center[i]; for(int j=0;j<cs[0].length;j++){ nct[belong][j]+=n[j]; } datacount[belong]++; } for(int i=0;i<nct.length;i++){ for(int j=0;j<nct[0].length;j++){ nct[i][j]/=datacount[i]; } } return nct; } public void printCentroids(){ java.text.DecimalFormat df=new java.text.DecimalFormat("0.00"); for(int i=0;i<centroids.length;i++){ for(int j=0;j<centroids[i].length;j++) System.out.print(df.format(centroids[i][j])+","); System.out.println(); } } }
d.調用
import java.io.BufferedReader; import java.io.FileReader; public class Test { static void compare(double[][] input){ double[][] cs = KMeans.UtilZ.randomCentroids(input, 20); int t=1; long start = System.currentTimeMillis(); while(t-->0) KMeans.build(input, 20, 0.001,cs); long kdtree = System.currentTimeMillis()-start; t=1; start = System.currentTimeMillis(); while(t-->0) KMeans.buildOri(input, 20, 0.001,cs); long ori = System.currentTimeMillis()-start; System.out.println("kdtree:"+kdtree); System.out.println("linear:"+ori); System.out.println(ori*1.0/kdtree); } public static void main(String[] args) throws Exception{ BufferedReader reader = new BufferedReader(new FileReader("d.txt")); String line=null; double[][] input = new double[600000][10]; int i=0; while((line=reader.readLine())!=null){ String[] numstrs=line.split("\t"); for(int j=0;j<10;j++) input[i][j] = Double.parseDouble(numstrs[j]); i++; } compare(input); } }
5.總結
對於數據量較小、中心點較少、維度不多的情景中,使用kd-tree並不能加速,反而比原始的算法更慢,因為kd-tree的構建花費了很長時間;
此時在選擇分割維度的時候不用方差,而用數據范圍,能加快kd-tree 的構建,但會下降一定的kd-tree查詢性能;
當數據量大,中心點多,維度大的情況下或者在x-mean算法中,應該使用方差作為選擇分割維度,此時查詢性能的提升能彌補kd-tee構建的耗時