使用kd-tree加速k-means


0.目錄

 

 

1.前置知識

本文內容基於《Accelerating exact k-means algorithms with geometric reasoning
KDTree
k-means

2.思路介紹

k-means算法在初始化中心點后C通過以下迭代步驟得到局部最優解:
  a.將數據集D中的點x賦給距離最近的中心點
  b.在每個聚類中,重新計算中心點
傳統算法中,a步需要計算n*k個距離(n為D的大小,k為聚類個數),b步需要相加n個數據點
而在KDTree中,每個非葉子節點,都存儲了其包含的數據的數據范圍信息h。

二維空間中的h可以使用矩形來表示
圖中*為點,紅色矩形為數據范圍h 

  a. 如果通過范圍信息,能判斷節點中數據都屬於中心點c,則能省去節點中數據到中心點距離的計算
     如果能判斷h中數據都不屬於某中心點c,則能省去節點中數據到中心點c距離的計算
  b. 當知道節點中數據全部屬於c,能將h中事先加好的統計量直接加到c的統計量中

3.詳述

3.1 確定h的中心點(h中所有數據都離這個中心點近而離其他中心點遠)

 

KDTree的節點中存儲的Max(各維度上的最大值)和Min(各維度上的最小值)確定了節點中數據的范圍
中心點有(c1,c2,...,ck)
a. 判斷是否可能存在
  計算各中心點到h的最小距離(參考KDTree最近鄰查找,第5步) d(ci,h)
  如果存在一個最小距離,則這個ci可能是h的中心點(還需要進一步判斷)
  若存在不止一個最小距離,則h的中心點不存在,需要將h分割為更小(在h的左右樹上)后查找   

正方形表示的點都在h的內部
所以他們到h的最小距離相同,都為0
此h不存在中心點     

b. 進一步判斷,ci是否為中心點
  

L12為c1和c2連線的中位線,h全部落在c1一邊,
所以h中的全部點離c1比離c2近,稱c1優於c2

而對於c1和c3來說,h有一部分落在c1,有一部分落在c3
c1不優於c3
判斷c1是否優於c3:
取向量v=(c3-c1),找到點p屬於h,使<v,p>內積最大
v各維度正負情況(+,-),則p在x軸上盡可能大,y軸上盡可能小,取到p13
p13離c3近,所以c1不優於c3

  如果ci在優於其他點,則可以判定ci即為h的中心點;否則ci不是h的中心點;
  雖然ci不是h的中心點,但是得到的信息,如ci優於c2,能將c2從h的子樹的中心點候選列表中排除

3.2 算法步驟

 

KDTree中每個非葉子節點特殊屬性:
sumOfPoints:m維向量(m是數據的維度),其i維度的值為節點中數據第i維的和
n:節點中數據的個數
輸入:KDTree,C 包括中心點(c1,c2,...,ck)
輸出:CNEW 新的k個中心點
node=KDTree.root
centers=k*m的數組//每行存儲屬於這個中心點的數據的和
datacount=k*1的數組//存儲屬於這個中心點的數據個數
UPDATE(node,C):
IF node為葉子節點
  遍歷計算得到離node最近的節點ct
  centers[t]+=node.value;
  datacount[t]+=1;
  RETURN;

FOR(ci in C)  計算d(ci,node.h)
IF 有多個最小的d(ci,node.h)
  UPDATE(node.left,C);
  UPDATE(node.right,C);
  RETURN;
//假設d(ci,node.h)最小的是ct
CTOVER=[]//存儲劣於ct的
FOR(ci in C(除了ct))  IF(ct 優於 ci) CTOVER.ADD(ci)
IF(LEN(CTOVER)=LEN(C)-1)//ct優於其他的中心點
  centers[t]+=node.sumOfPoints;
  datacount[t]+=node.n;
  RETURN;
CT=(ci in C 且 ci not in CTOVER)//排除比ct差的中心點
UPDATE(node.left,CT);
UPDATE(node.right,CT);
RETURN;

4.java實現

a.用下列matlab方法生成測試數據

#centers為中心點個數,dimention為數據維度,persize為每個中心點包含的數據量
function cdata(centers,dimention,persize) d
=zeros(centers*persize,dimention); sigma=eye(dimention); for i=1:centers mu=randi(20,1,dimention); d(((i-1)*persize+1):i*persize,:)=mvnrnd(mu,sigma,persize); end dlmwrite('d.txt',d,'delimiter','\t','precision','%10.4f') end

b.kdtree

package cc;
import java.util.ArrayList;
import java.util.HashMap;

public class MRKDTree {
    
    private Node mrkdtree;
    
    private class Node{
        //分割的維度
        int partitionDimention;
        //分割的值
        double partitionValue;
        //如果為非葉子節點,該屬性為空
        //否則為數據
        double[] value;
        //是否為葉子
        boolean isLeaf=false;
        //左樹
        Node left;
        //右樹
        Node right;
        //每個維度的最小值
        double[] min;
        //每個維度的最大值
        double[] max;
        
        double[] sumOfPoints;
        int n;
    }
    
    private static class UtilZ{
        /**
         * 計算給定維度的方差
         * @param data 數據
         * @param dimention 維度
         * @return 方差
         */
        static double variance(ArrayList<double[]> data,int dimention){
            double vsum = 0;
            double sum = 0;
            for(double[] d:data){
                sum+=d[dimention];
                vsum+=d[dimention]*d[dimention];
            }
            int n = data.size();
            return vsum/n-Math.pow(sum/n, 2);
        }
        /**
         * 取排序后的中間位置數值
         * @param data 數據
         * @param dimention 維度
         * @return
         */
        static double median(ArrayList<double[]> data,int dimention){
            double[] d =new double[data.size()];
            int i=0;
            for(double[] k:data){
                d[i++]=k[dimention];
            }
            return median(d);
        }
        
        private static double median(double[] a){
            int n=a.length;
            int L = 0;
            int R = n - 1;
            int k = n / 2;
            int i;
            int j;
            while (L < R) {
                double x = a[k];
                i = L;
                j = R;
                do {
                    while (a[i] < x)
                        i++;
                    while (x < a[j])
                        j--;
                    if (i <= j) {
                        double t = a[i];
                        a[i] = a[j];
                        a[j] = t;
                        i++;
                        j--;
                    }
                } while (i <= j);
                if (j < k)
                    L = i;
                if (k < i)
                    R = j;
            }
            return a[k];
        }
        
        static double[][] maxmin(ArrayList<double[]> data,int dimentions){
            double[][] mm = new double[2][dimentions];
            //初始化 第一行為min,第二行為max
            for(int i=0;i<dimentions;i++){
                mm[0][i]=mm[1][i]=data.get(0)[i];
                for(int j=1;j<data.size();j++){
                    double[] d = data.get(j);
                    if(d[i]<mm[0][i]){
                        mm[0][i]=d[i];
                    }else if(d[i]>mm[1][i]){
                        mm[1][i]=d[i];
                    }
                }
            }
            return mm;
        }
        
        static double distance(double[] a,double[] b){
            double sum = 0;
            for(int i=0;i<a.length;i++){
                sum+=Math.pow(a[i]-b[i], 2);
            }
            return sum;
        }
        
        /**
         * 在max和min表示的超矩形中的點和點a的最小距離
         * @param a 點a
         * @param max 超矩形各個維度的最大值
         * @param min 超矩形各個維度的最小值
         * @return 超矩形中的點和點a的最小距離
         */
        static double mindistance(double[] a,double[] max,double[] min){
            double sum = 0;
            for(int i=0;i<a.length;i++){
                if(a[i]>max[i])
                    sum += Math.pow(a[i]-max[i], 2);
                else if (a[i]<min[i]) {
                    sum += Math.pow(min[i]-a[i], 2);
                }
            }
            
            return sum;
        }
        
        public static double[] sumOfPoints(ArrayList<double[]> data,
                int dimentions) {
            double[] res = new double[dimentions];
            for(double[] d:data){
                for(int i=0;i<dimentions;i++){
                    res[i]+=d[i];
                }
            }
            return res;
        }
        /**
         * 判斷centerd是否在h上優於c
         * @param centerd
         * @param c
         * @param max
         * @param min
         * @return
         */
        public static boolean isOver(double[] center, double[] c,
                double[] max, double[] min) {
            double discenter = 0;
            double disc = 0;
            for(int i=0;i<c.length;i++){
                if(c[i]-center[i]>0){
                    disc+=Math.pow(max[i]-c[i],2);
                    discenter+=Math.pow(max[i]-center[i],2);
                }else if(c[i]-center[i]<0) {
                    disc+=Math.pow(min[i]-c[i],2);
                    discenter+=Math.pow(min[i]-center[i],2);
                }
                
            }
            return discenter<disc;
        }
    }
    
    private MRKDTree() {}
    /**
     * 構建樹
     * @param input 輸入
     * @return KDTree樹
     */
    public static MRKDTree build(double[][] input){
        int n = input.length;
        int m = input[0].length;
        
        ArrayList<double[]> data =new ArrayList<double[]>(n);
        for(int i=0;i<n;i++){
            double[] d = new double[m];
            for(int j=0;j<m;j++)
                d[j]=input[i][j];
            data.add(d);
        }
        
        MRKDTree tree = new MRKDTree();
        tree.mrkdtree = tree.new Node();
        tree.buildDetail(tree.mrkdtree, data, m,0);
        
        return tree;
    }
    /**
     * 循環構建樹
     * @param node 節點
     * @param data 數據
     * @param dimentions 數據的維度
     */
    private void buildDetail(Node node,ArrayList<double[]> data,int dimentions,int lv){
        if(data.size()==1){
            node.isLeaf=true;
            node.value=data.get(0);
            return;
        }
        
        //選擇方差最大的維度
        /*
        node.partitionDimention=-1;
        double var = -1;
        double tmpvar;
        for(int i=0;i<dimentions;i++){
            tmpvar=UtilZ.variance(data, i);
            if (tmpvar>var){
                var = tmpvar;
                node.partitionDimention = i;
            }
        }
        //如果方差=0,表示所有數據都相同,判定為葉子節點
        if(var<1e-10){
            node.isLeaf=true;
            node.value=data.get(0);
            return;
        }
        */
        double[][] maxmin=UtilZ.maxmin(data, dimentions);
        
        node.min = maxmin[0];
        node.max = maxmin[1];
        
        //選取方差大的維度,會需要很長時間
        //改成使用選取數據范圍最大的維度
        //這樣構建kdtree的速度會變快,但是在kmean更新中心點會變慢
        boolean isleaf = true;
        for(int i=0;i<node.min.length;i++)
            if(node.min[i]!=node.max[i]){
                isleaf=false;
                break;
            }
        
        if(isleaf){
            node.isLeaf=true;
            node.value=data.get(0);
            return;
        }
        
        node.partitionDimention=-1;
        double diff = -1;
        double tmpdiff;
        for(int i=0;i<dimentions;i++){
            tmpdiff=node.max[i]-node.min[i];
            if (tmpdiff>diff){
                diff = tmpdiff;
                node.partitionDimention = i;
            }
        }
        
        node.sumOfPoints = UtilZ.sumOfPoints(data,dimentions);
        node.n = data.size();
        
        //選擇分割的值
        node.partitionValue=UtilZ.median(data, node.partitionDimention);
        if(node.partitionValue==node.min[node.partitionDimention]){
            node.partitionValue+=1e-5;
        }
        
        int size = (int)(data.size()*0.55);
        ArrayList<double[]> left = new ArrayList<double[]>(size);
        ArrayList<double[]> right = new ArrayList<double[]>(size);
        
        for(double[] d:data){
            if (d[node.partitionDimention]<node.partitionValue) {
                left.add(d);
            }else {
                right.add(d);
            }
        }
        
        Node leftnode = new Node();
        Node rightnode = new Node();
        node.left=leftnode;
        node.right=rightnode;
        buildDetail(leftnode, left, dimentions,lv+1);
        buildDetail(rightnode, right, dimentions,lv+1);
    }
    
    public double[][] updateCentroids(double[][] cs){
        int k = cs.length;
        int m = cs[0].length;
        double[][] entroids = new double[k][m];
        int[] datacount = new int[k];
        HashMap<Integer, double[]> cscopy = new HashMap<Integer, double[]>();
        for(int i=0;i<k;i++)
            cscopy.put(i, cs[i]);
        
        updateCentroidsDetail(mrkdtree,cscopy,entroids,datacount,k,m);
        double[][] csnew = new double[k][m];
        for(int i=0;i<k;i++){
            for(int j=0;j<m;j++){
                csnew[i][j]=entroids[i][j]/datacount[i];
            }
        }
        
        return csnew;
    }
    
    private void updateCentroidsDetail(Node node,
            HashMap<Integer, double[]> cs, double[][] entroids,
            int[] datacount,int k,int m) {
        //如果是葉子節點
        if(node.isLeaf){
            double[] v=node.value;
            double dis=Double.MAX_VALUE;
            double tdis;
            int index = -1;
            //找到所屬的中心點
            for(Integer i: cs.keySet()){
                double[] c = cs.get(i);
                tdis = UtilZ.distance(c, v);
                if(tdis<dis){
                    dis=tdis;
                    index=i;
                }
            }
            
            //更新統計信息
            datacount[index]++;
            for(int i=0;i<m;i++){
                entroids[index][i]+=v[i];
            }
            return;
        }
        
        double[] stack = new double[k];
        int stackpoint = 0;
        int center=0;
        double tdis;
        for(Integer i: cs.keySet()){
            double[] c = cs.get(i);
            tdis = UtilZ.mindistance(c, node.max, node.min);
            if(stackpoint==0){
                stack[stackpoint++]=tdis;
                center=i;
            }else if (tdis<stack[stackpoint-1]) {
                stackpoint=1;
                stack[0]=tdis;
                center=i;
            }else if (tdis==stack[stackpoint-1]) {
                stack[stackpoint++]=tdis;
            }
            
        }
        //stackpoint>1,說明有多個最小值,不存在中心點
        if(stackpoint!=1){
            updateCentroidsDetail(node.left, cs, entroids, datacount, k, m);
            updateCentroidsDetail(node.right, cs, entroids, datacount, k, m);
            return;
        }
        
        HashMap<Integer, Boolean> ctover = new HashMap<Integer, Boolean>();
        double[] centerd = cs.get(center);
        for(Integer i: cs.keySet()){
            if(i==center) continue;
            double[] c = cs.get(i);
            if(UtilZ.isOver(centerd,c,node.max,node.min)){
                ctover.put(i, true);
            }
        }
        
        if(ctover.size()==cs.size()-1){
            //此時中心點即為center,更新信息
            datacount[center]+=node.n;
            for(int i=0;i<m;i++){
                entroids[center][i]+=node.sumOfPoints[i];
            }
            return;
        }
        
        //將其比center差的中心點排除
        HashMap<Integer, double[]> csnew = new HashMap<Integer, double[]>();
        for(Integer i:cs.keySet()){
            if(!ctover.containsKey(i))
                csnew.put(i, cs.get(i));
        }
        
        updateCentroidsDetail(node.left, csnew, entroids, datacount, k, m);
        updateCentroidsDetail(node.right, csnew, entroids, datacount, k, m);
    }
}

c.kmeans

import cc.MRKDTree;


public class KMeans {
    private double[][] centroids;
    
    private KMeans(){}
    
    public static class UtilZ{
        static double[][] randomCentroids(double[][] data,int k){
            double[][] res = new double[k][];
            for(int i=0;i<k;i++){
                res[i] = data[(int)(Math.random()*data.length)];
            }
            return res;
        }
        
        static boolean converged(double[][] c1,double[][] c2,double c){
            for(int i=0;i<c1.length;i++){
                if(changed(c1[i],c2[i])>c){
                    return false;
                }
            }
            return true;
        }
        private static double changed(double[] c1,double[] c2){
            double change=0;
            double total=0;
            for(int i=0;i<c1.length;i++){
                total+=Math.pow(c1[i], 2);
                change+=Math.pow(c1[i]-c2[i], 2);
            }
            return Math.sqrt(change/total);
        }
        
        static double distance(double[] c1,double[] c2){
            double sum = 0;
            for(int i=0;i<c1.length;i++){
                sum+=Math.pow(c1[i]-c2[i], 2);
            }
            return sum;
        }
    }
    public static KMeans build(double[][] input,int k,double c,double[][] cs){
        long start = System.currentTimeMillis();
        MRKDTree tree = MRKDTree.build(input);
        System.out.println("treeConstruct:"+(System.currentTimeMillis()-start));
        
        double[][] csnew = tree.updateCentroids(cs);
        while(!UtilZ.converged(cs, csnew, c)){
            cs=csnew;
            csnew=tree.updateCentroids(cs);
        }
        KMeans km = new KMeans();
        km.centroids=csnew;
        return km;
    }
    
    public static KMeans buildOri(double[][] input,int k,double c,double[][] cs){
        
        double[][] csnew = updateOri(input,cs);
        while(!UtilZ.converged(cs, csnew, c)){
            cs=csnew;
            csnew=updateOri(input,cs);
        }
        KMeans km = new KMeans();
        km.centroids=csnew;
        return km;
    }
    
    
    private static double[][] updateOri(double[][] input,double[][] cs){
        int[] center = new int[input.length];
        for(int i=0;i<input.length;i++){
            double dismin = Double.MAX_VALUE;
            for(int j=0;j<cs.length;j++){
                double dis = UtilZ.distance(input[i], cs[j]);
                if(dis<dismin){
                    dismin=dis;
                    center[i]=j;
                }
            }
        }
        
        double[][] nct =new double[cs.length][cs[0].length];
        int[] datacount = new int[cs.length];
        for(int i=0;i<input.length;i++){
            double[] n = input[i];
            int belong = center[i];
            for(int j=0;j<cs[0].length;j++){
                nct[belong][j]+=n[j];
            }
            datacount[belong]++;
        }
        
        for(int i=0;i<nct.length;i++){
            for(int j=0;j<nct[0].length;j++){
                nct[i][j]/=datacount[i];
            }
        }
        return nct;
    }
    
    public void printCentroids(){
        java.text.DecimalFormat df=new java.text.DecimalFormat("0.00"); 
        for(int i=0;i<centroids.length;i++){
            for(int j=0;j<centroids[i].length;j++)
                System.out.print(df.format(centroids[i][j])+",");
            System.out.println();
        }
    }
}

d.調用

import java.io.BufferedReader;
import java.io.FileReader;

public class Test {
    static void compare(double[][] input){
        double[][] cs = KMeans.UtilZ.randomCentroids(input, 20);
        int t=1;
        long start = System.currentTimeMillis();
        while(t-->0)
            KMeans.build(input, 20, 0.001,cs);
        long kdtree = System.currentTimeMillis()-start;
        t=1;
        start = System.currentTimeMillis();
        while(t-->0)
            KMeans.buildOri(input, 20, 0.001,cs);
        long ori = System.currentTimeMillis()-start;
        
        System.out.println("kdtree:"+kdtree);
        System.out.println("linear:"+ori);
        System.out.println(ori*1.0/kdtree);
    }
    
    public static void main(String[] args) throws Exception{
        BufferedReader reader = new BufferedReader(new FileReader("d.txt"));
        String line=null;
        double[][] input = new double[600000][10];
        int i=0;
        while((line=reader.readLine())!=null){
            String[] numstrs=line.split("\t");
            for(int j=0;j<10;j++)
                input[i][j] = Double.parseDouble(numstrs[j]);
            i++;
        }
        
        compare(input);
    }
}

5.總結

對於數據量較小、中心點較少、維度不多的情景中,使用kd-tree並不能加速,反而比原始的算法更慢,因為kd-tree的構建花費了很長時間;

此時在選擇分割維度的時候不用方差,而用數據范圍,能加快kd-tree 的構建,但會下降一定的kd-tree查詢性能;

當數據量大,中心點多,維度大的情況下或者在x-mean算法中,應該使用方差作為選擇分割維度,此時查詢性能的提升能彌補kd-tee構建的耗時

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM