KDTree詳解及java實現


本文內容基於An introductory tutoril on kd-trees

1.KDTree介紹

KDTree根據m維空間中的數據集D構建的二叉樹,能加快常用於最近鄰查找(在加快k-means算法中有應用)。

其節點具有如下屬性(對應第5節中的程序實現):

非葉子節點(不存儲數據):

partitionDimention

用於分割的維度,取值范圍為1,2,…,m

partitionValue

用於分割的值v,當數據點在維度partitionDimention上的值小於v時,被分到左節點,否則分到右節點

left

左節點,使用分到該節點的數據集構建

right

右節點

Max(以及min,是加快最近鄰查找的關鍵,在第3節會講到)

用於構建該節點的數據集(也可以說是該節點的所有葉子節點包含的數據組成的數據集)在各個維度上的最大值組成的d維向量

Min

用於構建該節點的數據集在各個維度上的最小值組成的d維向量

葉子節點:

value

存儲的數據(只存儲一個數據)


private
class Node{ //分割的維度 int partitionDimention; //分割的值 double partitionValue; //如果為非葉子節點,該屬性為空 //否則為數據 double[] value; //是否為葉子 boolean isLeaf=false; //左樹 Node left; //右樹 Node right; //每個維度的最小值 double[] min; //每個維度的最大值 double[] max; }

2.KDTree構建

輸入:數據集D

輸出:KDTree

a.如果D為空,返回空的KDTree
b.新建節點node

c.如果D只有一個數據或D中數據全部相同

   將node標記為葉子節點

d.否則將node標記為非葉子節點

  取各維度上的最大最小值分別生成Max和Min

  遍歷m個維度,找到方差最大的維度作為partitionDimention

  取數據集在partitionDimention維度上排序后的中點作為partitionValue

e.將數據集中在維度partitionDimention上小於partitionValue的划分為D1,

  其他數據點作為D2

  用數據集D1,循環a–e步,生成node的左樹

  用數據集D2,循環a–e步,生成node的右樹

以數據集  (2,3),(5,4),(4,7),(8,1),(7,2),(9,2) 為例子

a.D非空,跳到b
b.新建節點node

c.D有6個數據,且不相同,跳到d

d.標記node為非葉子節點

   在第一個維度上(數組[2,5,4,8,7,9])計算方差得到5.8

   在第二個維度上(數組[3,4,7,1,2,2])計算方差得到3.8

   第一個維度上方差較大,所以partitionDimention=0

   在第一個維度上排序([2,4,5,7,8, 9])取中點7作為分割點,partitionValue=7

 

   兩個維度上的最大值為[9,7]作為Max,兩個維度上的最小值[2,1]作為Min

e.數據集D,在第一個維度上小於7的分入D1:(2,3),(5,4),(4,7)

   大於等於7的分入D2:(8,1),(7,2),(9,2)

   用D1構建左樹,用D2構建右樹

 

     if(data.size()==1){
            node.isLeaf=true;
            node.value=data.get(0);
            return;
        }
        
        //選擇方差最大的維度
        node.partitionDimention=-1;
        double var = -1;
        double tmpvar;
        for(int i=0;i<dimentions;i++){
            tmpvar=UtilZ.variance(data, i);
            if (tmpvar>var){
                var = tmpvar;
                node.partitionDimention = i;
            }
        }
        //如果方差=0,表示所有數據都相同,判定為葉子節點
        if(var==0){
            node.isLeaf=true;
            node.value=data.get(0);
            return;
        }
        
        //選擇分割的值
        node.partitionValue=UtilZ.median(data, node.partitionDimention);
        
        double[][] maxmin=UtilZ.maxmin(data, dimentions);
        node.min = maxmin[0];
        node.max = maxmin[1];
        
        int size = (int)(data.size()*0.55);
        ArrayList<double[]> left = new ArrayList<double[]>(size);
        ArrayList<double[]> right = new ArrayList<double[]>(size);
        
        for(double[] d:data){
            if (d[node.partitionDimention]<node.partitionValue) {
                left.add(d);
            }else {
                right.add(d);
            }
        }
        Node leftnode = new Node();
        Node rightnode = new Node();
        node.left=leftnode;
        node.right=rightnode;
        buildDetail(leftnode, left, dimentions);
        buildDetail(rightnode, right, dimentions);

 

3.KDTree最近鄰查找

    KDTree能實現快速查找的原因:

    一個節點下的所有葉子節點包含的數據所處的范圍可以用一個矩形框住(數據為二維),對應到的屬性就是MaxMin

圖中*(2,3),(5,4),(4,7),(8,1),(7,2),(9,2),可以用[9,7][2,1]框住

此時,判斷方框中是否有和點o (10,4),距離小於t的點

通過以下方法可以初步判斷:

方框到o的距離最小的點,為圖中的正方形表示的點(9,4)

10>9   第一個維度為91<4<7 ,第二個維度為4

當找到一個相對較近的點,得到距離t后,只需要在樹中找到距離比t小的點,此時如果上面的距離大於t,就可以略過這個節點,從而減少很多計算

    查詢步驟:

    輸入:查詢點input

1.  從根節點出發,根據partitionDimentionpartitionValue一路向下直到葉子節點

並一路將路過節點外的其他節點加入棧中(如果進入左節點,就把右節點加入棧中)

用葉子節點上的值作為一個找到的初步最近鄰,記為nearest,和input的距離為distance

2.  若棧為空,返回nearest作為最近鄰

3.  否則從棧中取出節點node

4.  若此節點為葉子節點,計算它和input的距離tmpdis,若tmpdis<input

更新distance=tmpdisnearest=node.value

5.  若此節點為非葉子節點,使用MaxMin構建以下數據點h

h[i]= Max[i],若input[i]>Max[i]

         = Min[i],若input[i]<Min[i]

          = input[i],Min[i]<input[i]< Max[i]

計算ht的距離dis

6.  dis>=t,回到第2

7.  dis<t,根據partitionDimentionpartitionValue一路向下直到葉子節點

並一路將路過節點外的其他節點加入棧中(如果進入左節點,就把右節點加入棧中)

8.  計算它和input的距離tmpdis,若tmpdis<input

更新distance=tmpdisnearest=node.value

進入第2

 

        double[] nearest = null;
        Node node = null;
        double tdis;
        while(stack.size()!=0){
            node = stack.pop();
            if(node.isLeaf){
                 tdis=UtilZ.distance(input, node.value);
                 if(tdis<distance){
                     distance = tdis;
                     nearest = node.value;
                 }
            }else {
                /*
                 * 得到該節點代表的超矩形中點到查找點的最小距離mindistance
                 * 如果mindistance<distance表示有可能在這個節點的子節點上找到更近的點
                 * 否則不可能找到
                 */
                double mindistance = UtilZ.mindistance(input, node.max, node.min);
                if (mindistance<distance) {
                    while(!node.isLeaf){
                        if(input[node.partitionDimention]<node.partitionValue){
                            stack.add(node.right);
                            node=node.left;
                        }else{
                            stack.push(node.left);
                            node=node.right;
                        }
                    }
                    tdis=UtilZ.distance(input, node.value);
                    if(tdis<distance){
                        distance = tdis;
                        nearest = node.value;
                    }
                }
            }
        }
        return nearest;

 

4.KDTree實現

   可以發現當數據量是10000時,kdtree比線性查找要塊134倍

 

   datasize:10000;iteration:100000
   kdtree:468
   linear:63125
   linear/kdtree:134.88247863247864

 

import java.util.ArrayList;
import java.util.Stack;

public class KDTree {
    
    private Node kdtree;
    
    private class Node{
        //分割的維度
        int partitionDimention;
        //分割的值
        double partitionValue;
        //如果為非葉子節點,該屬性為空
        //否則為數據
        double[] value;
        //是否為葉子
        boolean isLeaf=false;
        //左樹
        Node left;
        //右樹
        Node right;
        //每個維度的最小值
        double[] min;
        //每個維度的最大值
        double[] max;
    }
    
    private static class UtilZ{
        /**
         * 計算給定維度的方差
         * @param data 數據
         * @param dimention 維度
         * @return 方差
         */
        static double variance(ArrayList<double[]> data,int dimention){
            double vsum = 0;
            double sum = 0;
            for(double[] d:data){
                sum+=d[dimention];
                vsum+=d[dimention]*d[dimention];
            }
            int n = data.size();
            return vsum/n-Math.pow(sum/n, 2);
        }
        /**
         * 取排序后的中間位置數值
         * @param data 數據
         * @param dimention 維度
         * @return
         */
        static double median(ArrayList<double[]> data,int dimention){
            double[] d =new double[data.size()];
            int i=0;
            for(double[] k:data){
                d[i++]=k[dimention];
            }
            return findPos(d, 0, d.length-1, d.length/2);
        }
        
        static double[][] maxmin(ArrayList<double[]> data,int dimentions){
            double[][] mm = new double[2][dimentions];
            //初始化 第一行為min,第二行為max
            for(int i=0;i<dimentions;i++){
                mm[0][i]=mm[1][i]=data.get(0)[i];
                for(int j=1;j<data.size();j++){
                    double[] d = data.get(j);
                    if(d[i]<mm[0][i]){
                        mm[0][i]=d[i];
                    }else if(d[i]>mm[1][i]){
                        mm[1][i]=d[i];
                    }
                }
            }
            return mm;
        }
        
        static double distance(double[] a,double[] b){
            double sum = 0;
            for(int i=0;i<a.length;i++){
                sum+=Math.pow(a[i]-b[i], 2);
            }
            return sum;
        }
        
        /**
         * 在max和min表示的超矩形中的點和點a的最小距離
         * @param a 點a
         * @param max 超矩形各個維度的最大值
         * @param min 超矩形各個維度的最小值
         * @return 超矩形中的點和點a的最小距離
         */
        static double mindistance(double[] a,double[] max,double[] min){
            double sum = 0;
            for(int i=0;i<a.length;i++){
                if(a[i]>max[i])
                    sum += Math.pow(a[i]-max[i], 2);
                else if (a[i]<min[i]) {
                    sum += Math.pow(min[i]-a[i], 2);
                }
            }
            
            return sum;
        }
        
        /**
         * 使用快速排序,查找排序后位置在point處的值
         * 比Array.sort()后去對應位置值,大約快30%
         * @param data 數據
         * @param low 參加排序的最低點
         * @param high 參加排序的最高點
         * @param point 位置
         * @return
         */
        private static double findPos(double[] data,int low,int high,int point){
            int lowt=low;
            int hight=high;
            double v = data[low];
            ArrayList<Integer> same = new ArrayList<Integer>((int)((high-low)*0.25));
            while(low<high){
                while(low<high&&data[high]>=v){
                    if(data[high]==v){
                        same.add(high);
                    }
                    high--;
                }
                data[low]=data[high];
                while(low<high&&data[low]<v)
                    low++;
                data[high]=data[low];
            }
            data[low]=v;
            int upper = low+same.size();
            if (low<=point&&upper>=point) {
                return v;
            }
            
            if(low>point){
                return findPos(data, lowt, low-1, point);
            }
            
            int i=low+1;
            for(int j:same){
                if(j<=low+same.size())
                    continue;
                while(data[i]==v)
                    i++;
                data[j]=data[i];
                data[i]=v;
                i++;
            }
            
            return findPos(data, low+same.size()+1, hight, point);
        }
    }
    
    private KDTree() {}
    /**
     * 構建樹
     * @param input 輸入
     * @return KDTree樹
     */
    public static KDTree build(double[][] input){
        int n = input.length;
        int m = input[0].length;
        
        ArrayList<double[]> data =new ArrayList<double[]>(n);
        for(int i=0;i<n;i++){
            double[] d = new double[m];
            for(int j=0;j<m;j++)
                d[j]=input[i][j];
            data.add(d);
        }
        
        KDTree tree = new KDTree();
        tree.kdtree = tree.new Node();
        tree.buildDetail(tree.kdtree, data, m);
        
        return tree;
    }
    /**
     * 循環構建樹
     * @param node 節點
     * @param data 數據
     * @param dimentions 數據的維度
     */
    private void buildDetail(Node node,ArrayList<double[]> data,int dimentions){
        if(data.size()==1){
            node.isLeaf=true;
            node.value=data.get(0);
            return;
        }
        
        //選擇方差最大的維度
        node.partitionDimention=-1;
        double var = -1;
        double tmpvar;
        for(int i=0;i<dimentions;i++){
            tmpvar=UtilZ.variance(data, i);
            if (tmpvar>var){
                var = tmpvar;
                node.partitionDimention = i;
            }
        }
        //如果方差=0,表示所有數據都相同,判定為葉子節點
        if(var==0){
            node.isLeaf=true;
            node.value=data.get(0);
            return;
        }
        
        //選擇分割的值
        node.partitionValue=UtilZ.median(data, node.partitionDimention);
        
        double[][] maxmin=UtilZ.maxmin(data, dimentions);
        node.min = maxmin[0];
        node.max = maxmin[1];
        
        int size = (int)(data.size()*0.55);
        ArrayList<double[]> left = new ArrayList<double[]>(size);
        ArrayList<double[]> right = new ArrayList<double[]>(size);
        
        for(double[] d:data){
            if (d[node.partitionDimention]<node.partitionValue) {
                left.add(d);
            }else {
                right.add(d);
            }
        }
        Node leftnode = new Node();
        Node rightnode = new Node();
        node.left=leftnode;
        node.right=rightnode;
        buildDetail(leftnode, left, dimentions);
        buildDetail(rightnode, right, dimentions);
    }
    /**
     * 打印樹,測試時用
     */
    public void print(){
        printRec(kdtree,0);
    }
    
    private void printRec(Node node,int lv){
        if(!node.isLeaf){
            for(int i=0;i<lv;i++)
                System.out.print("--");
            System.out.println(node.partitionDimention+":"+node.partitionValue);
            printRec(node.left,lv+1);
            printRec(node.right,lv+1);
        }else {
            for(int i=0;i<lv;i++)
                System.out.print("--");
            StringBuilder s = new StringBuilder();
            s.append('(');
            for(int i=0;i<node.value.length-1;i++){
                s.append(node.value[i]).append(',');
            }
            s.append(node.value[node.value.length-1]).append(')');
            System.out.println(s);
        }
    }
    
    public double[] query(double[] input){
        Node node = kdtree;
        Stack<Node> stack = new Stack<Node>();
        while(!node.isLeaf){
            if(input[node.partitionDimention]<node.partitionValue){
                stack.add(node.right);
                node=node.left;
            }else{
                stack.push(node.left);
                node=node.right;
            }
        }
        /**
         * 首先按樹一路下來,得到一個想對較近的距離,再找比這個距離更近的點
         */
        double distance = UtilZ.distance(input, node.value);
        double[] nearest=queryRec(input, distance, stack);
        return nearest==null? node.value:nearest;
    }
    
    public double[] queryRec(double[] input,double distance,Stack<Node> stack){
        double[] nearest = null;
        Node node = null;
        double tdis;
        while(stack.size()!=0){
            node = stack.pop();
            if(node.isLeaf){
                 tdis=UtilZ.distance(input, node.value);
                 if(tdis<distance){
                     distance = tdis;
                     nearest = node.value;
                 }
            }else {
                /*
                 * 得到該節點代表的超矩形中點到查找點的最小距離mindistance
                 * 如果mindistance<distance表示有可能在這個節點的子節點上找到更近的點
                 * 否則不可能找到
                 */
                double mindistance = UtilZ.mindistance(input, node.max, node.min);
                if (mindistance<distance) {
                    while(!node.isLeaf){
                        if(input[node.partitionDimention]<node.partitionValue){
                            stack.add(node.right);
                            node=node.left;
                        }else{
                            stack.push(node.left);
                            node=node.right;
                        }
                    }
                    tdis=UtilZ.distance(input, node.value);
                    if(tdis<distance){
                        distance = tdis;
                        nearest = node.value;
                    }
                }
            }
        }
        return nearest;
    }
    
    /**
     * 線性查找,用於和kdtree查詢做對照
     * 1.判斷kdtree實現是否正確
     * 2.比較性能
     * @param input
     * @param data
     * @return
     */
    public static double[] nearest(double[] input,double[][] data){
        double[] nearest=null;
        double dis = Double.MAX_VALUE;
        double tdis;
        for(int i=0;i<data.length;i++){
            tdis = UtilZ.distance(input, data[i]);
            if(tdis<dis){
                dis=tdis;
                nearest = data[i];
            }
        }
        return nearest;
    }
    
    /**
     * 運行100000次,看運行結果是否和線性查找相同
     */
    public static void correct(){
        int count = 100000;
        while(count-->0){
            int num = 100;
            double[][] input = new double[num][2];
            for(int i=0;i<num;i++){
                input[i][0]=Math.random()*10;
                input[i][1]=Math.random()*10;
            }
            double[] query = new double[]{Math.random()*50,Math.random()*50};
            
            KDTree tree=KDTree.build(input);
            double[] result = tree.query(query);
            double[] result1 = nearest(query,input);
            if (result[0]!=result1[0]||result[1]!=result1[1]) {
                System.out.println("wrong");
                break;
            }
        }
    }
    
    public static void performance(int iteration,int datasize){
        int count = iteration;
        
        int num = datasize;
        double[][] input = new double[num][2];
        for(int i=0;i<num;i++){
            input[i][0]=Math.random()*num;
            input[i][1]=Math.random()*num;
        }
        
        KDTree tree=KDTree.build(input);
        
        double[][] query = new double[iteration][2];
        for(int i=0;i<iteration;i++){
            query[i][0]= Math.random()*num*1.5;
            query[i][1]= Math.random()*num*1.5;
        }
        
        long start = System.currentTimeMillis();
        for(int i=0;i<iteration;i++){
            double[] result = tree.query(query[i]);
        }
        long timekdtree = System.currentTimeMillis()-start;
        
        start = System.currentTimeMillis();
        for(int i=0;i<iteration;i++){
            double[] result = nearest(query[i],input);
        }
        long timelinear = System.currentTimeMillis()-start;
        
        System.out.println("datasize:"+datasize+";iteration:"+iteration);
        System.out.println("kdtree:"+timekdtree);
        System.out.println("linear:"+timelinear);
        System.out.println("linear/kdtree:"+(timelinear*1.0/timekdtree));
    }
    
    public static void main(String[] args) {
        //correct();
        performance(100000,10000);
    }
}

 




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM