統計學習方法學習(四)--KNN及kd樹的java實現


K近鄰法

1基本概念

       K近鄰法,是一種基本分類和回歸規則。根據已有的訓練數據集(含有標簽),對於新的實例,根據其最近的k個近鄰的類別,通過多數表決的方式進行預測。

2模型相關

2.1 距離的度量方式

       定義距離

       (1)歐式距離:p=2。

       (2)曼哈頓距離:p=1。

       (3)各坐標的最大值:p=∞。

2.2 K值的選擇

       通常使用交叉驗證法來選取最優的k值。

       k值大小的影響:

       k越小,只有距該點較近的實例才會起作用,學習的近似誤差會較小。但此時又會對這些近鄰的實例很敏感,如果緊鄰點存在噪聲,預測就會出錯,即學習的估計誤差大,泛化能力不好。

       K越大,距該點較遠的實例也會起作用,造成近似誤差增大,使預測發生錯誤。

2.3 k近鄰法的實現:kd樹

  Kd樹是二叉樹。kd樹是一種對K維空間中的實例點進行存儲以便對其進行快速檢索的樹形數據結構.

  Kd樹是二叉樹, 表示對K維空間的一個划分( partition).構造Kd樹相 當於不斷地用垂直於坐標軸的超平面將k維空間切分, 構成一系列的k維超矩形區域.Kd樹的每個結點對應於一個k維超矩形區域

  其中,創建kd樹時,垂直於坐標軸的超平面垂直的坐標軸選擇是:

  L=(J mod k)+1。其中,j為當前節點的節點深度,k為k維空間(給定實例點的k個維度)。根節點的節點深度為0.此公式可看為:依次循環實例點的k個維所對應的坐標軸。

  Kd樹的節點(分割點)為L維上所有實例點的中位數。

2.4 Kd樹的實現

  別處代碼實現基於其他博客,但是糾正了其中的錯誤,能夠返回前k個近鄰。如果要求最近鄰,只需要將k=1即可。

  

 1 public class BinaryTreeOrder {
 2     
 3     public void preOrder(Node root) {
 4         if(root!= null){
 5             System.out.print(root.toString());
 6             preOrder(root.left);
 7             preOrder(root.right);
 8         }
 9     }
10 }
public class kd_main {

    public static void main(String[] args) {
        List<Node> nodeList=new ArrayList<Node>();
        
        nodeList.add(new Node(new double[]{5,4}));
        nodeList.add(new Node(new double[]{9,6}));
       
        nodeList.add(new Node(new double[]{8,1}));
        nodeList.add(new Node(new double[]{7,2}));
        nodeList.add(new Node(new double[]{2,3}));
        nodeList.add(new Node(new double[]{4,7}));
        nodeList.add(new Node(new double[]{4,3}));
        nodeList.add(new Node(new double[]{1,3}));

        kd_main kdTree=new kd_main();
        Node root=kdTree.buildKDTree(nodeList,0);
        new BinaryTreeOrder().preOrder(root);
        for (Node node : nodeList) {
            System.out.println(node.toString()+"-->"+node.left.toString()+"-->"+node.right.toString());
        }
        System.out.println(root);
        System.out.println(kdTree.searchKNN(root,new Node(new double[]{2.1,3.1}),2));
        System.out.println(kdTree.searchKNN(root,new Node(new double[]{2,4.5}),1));
        System.out.println(kdTree.searchKNN(root,new Node(new double[]{2,4.5}),3));
        System.out.println(kdTree.searchKNN(root,new Node(new double[]{6,1}),5));

    }
     
    
    
    /**
     * 構建kd樹  返回根節點
     * @param nodeList
     * @param index
     * @return
     */
    public Node buildKDTree(List<Node> nodeList,int index)
    {
        if(nodeList==null || nodeList.size()==0)
            return null;
        quickSortForMedian(nodeList,index,0,nodeList.size()-1);//中位數排序
        Node root=nodeList.get(nodeList.size()/2);//中位數 當做根節點
        root.dim=index;
        List<Node> leftNodeList=new ArrayList<Node>();//放入左側區域的節點  包括包含與中位數等值的節點-_-
        List<Node> rightNodeList=new ArrayList<Node>();
         
        for(Node node:nodeList)
        {
            if(root!=node)
            {
                if(node.getData(index)<=root.getData(index))
                    leftNodeList.add(node);//左子區域 包含與中位數等值的節點
                else
                    rightNodeList.add(node);
            }
        }
         
        //計算從哪一維度切分
        int newIndex=index+1;//進入下一個維度
        if(newIndex>=root.data.length)
            newIndex=0;//從0維度開始再算
        
        
        root.left=buildKDTree(leftNodeList,newIndex);//添加左右子區域
        root.right=buildKDTree(rightNodeList,newIndex);
         
        if(root.left!=null)
            root.left.parent=root;//添加父指針  
        if(root.right!=null)
            root.right.parent=root;//添加父指針  
        return root;
    }
     
     
    /**
     * 查詢最近鄰
     * @param root kd樹
     * @param q 查詢點
     * @param k
     * @return
     */
    public List<Node> searchKNN(Node root,Node q,int k)
    {
        List<Node> knnList=new ArrayList<Node>();      
        searchBrother(knnList,root,q,k);     
        return knnList;
    }
     
    /**
     * searhchBrother
     * @param knnList 
     * @param k 
     * @param q 
     */
    public void searchBrother(List<Node> knnList, Node root, Node q, int k) {
//         Node almostNNode=root;//近似最近點
         Node leafNNode=searchLeaf(root,q);
         double curD=q.computeDistance(leafNNode);//最近近似點與查詢點的距離 也就是球體的半徑
         leafNNode.distance=curD;
         maintainMaxHeap(knnList,leafNNode,k);
         System.out.println("leaf1"+leafNNode.getData(leafNNode.parent.dim));
         while(leafNNode!=root)
         {
             if (getBrother(leafNNode)!=null) {
                 Node brother=getBrother(leafNNode);
                 System.out.println("brother1"+brother.getData(brother.parent.dim));
                 if(curD>Math.abs(q.getData(leafNNode.parent.dim)-leafNNode.parent.getData(leafNNode.parent.dim))||knnList.size()<k)
                 {
                     //這樣可能在另一個子區域中存在更加近似的點
                     searchBrother(knnList,brother, q, k);
                 }
            }
           System.out.println("leaf2"+leafNNode.getData(leafNNode.parent.dim));
           leafNNode=leafNNode.parent;//返回上一級
           double rootD=q.computeDistance(leafNNode);//最近近似點與查詢點的距離 也就是球體的半徑
           leafNNode.distance=rootD;
           maintainMaxHeap(knnList,leafNNode,k);
          }
     }
            
    
    /**
     * 獲取兄弟節點
     * @param node
     * @return
     */
    public Node getBrother(Node node)
    {
        if(node==node.parent.left)
            return node.parent.right;
        else
            return node.parent.left;
    }
     
    /**
     * 查詢到葉子節點
     * @param root
     * @param q
     * @return
     */
    public Node searchLeaf(Node root,Node q)
    {
        Node leaf=root,next=null;
        int index=0;
        while(leaf.left!=null || leaf.right!=null)
        {
            if(q.getData(index)<leaf.getData(index))
            {
                next=leaf.left;//進入左側
            }else if(q.getData(index)>leaf.getData(index))
            {
                next=leaf.right;
            }else{
                //當取到中位數時  判斷左右子區域哪個更加近
                if(q.computeDistance(leaf.left)<q.computeDistance(leaf.right))
                    next=leaf.left;
                else
                    next=leaf.right;
            }
            if(next==null)
                break;//下一個節點是空時  結束了
            else{
                leaf=next;
                if(++index>=root.data.length)
                    index=0;
            }
        }
         
        return leaf;
    }
     
    /**
     * 維護一個k的最大堆
     * @param listNode
     * @param newNode
     * @param k
     */
    public void maintainMaxHeap(List<Node> listNode,Node newNode,int k)
    {
        if(listNode.size()<k)
        {
            maxHeapFixUp(listNode,newNode);//不足k個堆   直接向上修復
        }else if(newNode.distance<listNode.get(0).distance){
            //比堆頂的要小   還需要向下修復 覆蓋堆頂
            maxHeapFixDown(listNode,newNode);
        }
    }
     
    /**
     * 從上往下修復  將會覆蓋第一個節點
     * @param listNode
     * @param newNode
     */
    private void maxHeapFixDown(List<Node> listNode,Node newNode)
    {
        listNode.set(0, newNode);
        int i=0;
        int j=i*2+1;
        while(j<listNode.size())
        {
            if(j+1<listNode.size() && listNode.get(j).distance<listNode.get(j+1).distance)
                j++;//選出子結點中較大的點,第一個條件是要滿足右子樹不為空
             
            if(listNode.get(i).distance>=listNode.get(j).distance)
                break;
             
            Node t=listNode.get(i);
            listNode.set(i, listNode.get(j));
            listNode.set(j, t);
             
            i=j;
            j=i*2+1;
        }
    }
     
    private void maxHeapFixUp(List<Node> listNode,Node newNode)
    {
        listNode.add(newNode);
        int j=listNode.size()-1;
        int i=(j+1)/2-1;//i是j的parent節點
        while(i>=0)
        {
             
            if(listNode.get(i).distance>=listNode.get(j).distance)
                break;
             
            Node t=listNode.get(i);
            listNode.set(i, listNode.get(j));
            listNode.set(j, t);
             
            j=i;
            i=(j+1)/2-1;
        }
    }
     
     
     
    /**
     * 使用快排進進行一個中位數的查找  完了之后返回的數組size/2即中位數
     * @param nodeList
     * @param index
     * @param left
     * @param right
     */
    @Test
    private void quickSortForMedian(List<Node> nodeList,int index,int left,int right)
    {
        if(left>=right || nodeList.size()<=0)
            return ;
         
        Node kn=nodeList.get(left);
        double k=kn.getData(index);//取得向量指定索引的值
         
        int i=left,j=right;
        
        //控制每一次遍歷的結束條件,i與j相遇
        while(i<j)
        {
            //從右向左找一個小於i處值的值,並填入i的位置
            while(nodeList.get(j).getData(index)>=k && i<j)
                j--;
            nodeList.set(i, nodeList.get(j));
            //從左向右找一個大於i處值的值,並填入j的位置
            while(nodeList.get(i).getData(index)<=k && i<j)
                i++;
            nodeList.set(j, nodeList.get(i));
        }
        
        nodeList.set(i, kn);
        
        
        if(i==nodeList.size()/2)
            return ;//完成中位數的排序了,但並不是完成了所有數的排序,這個終止條件只是保證中位數是正確的。去掉該條件,可以保證在遞歸的作用下,將所有的樹
                    //將所有的數進行排序
        
        else if(i<nodeList.size()/2)
        {
            quickSortForMedian(nodeList,index,i+1,right);//只需要排序右邊就可以了
        }else{
            quickSortForMedian(nodeList,index,left,i-1);//只需要排序左邊就可以了
        }
        
//        for (Node node : nodeList) {
//            System.out.println(node.getData(index));
//        }
    }
}
public class Node implements Comparable<Node>{
    public double[] data;//樹上節點的數據  是一個多維的向量
    public double distance;//與當前查詢點的距離  初始化的時候是沒有的
    public Node left,right,parent;//左右子節點  以及父節點
    public int dim=-1;//維度  建立樹的時候判斷的維度
     
    public Node(double[] data)
    {
        this.data=data;
    }
     
    /**
     * 返回指定索引上的數值
     * @param index
     * @return
     */
    public double getData(int index)
    {
        if(data==null || data.length<=index)
            return Integer.MIN_VALUE;
        return data[index];
    }

    @Override
    public int compareTo(Node o) {
        if(this.distance>o.distance)
            return 1;
        else if(this.distance==o.distance)
            return 0;
        else return -1;
    }
     
    /**
     * 計算距離 這里返回歐式距離
     * @param that
     * @return
     */
    public double computeDistance(Node that)
    {
        if(this.data==null || that.data==null || this.data.length!=that.data.length)
            return Double.MAX_VALUE;//出問題了  距離最遠
        double d=0;
        for(int i=0;i<this.data.length;i++)
        {
            d+=Math.pow(this.data[i]-that.data[i], 2);
        }
         
        return Math.sqrt(d);
    }
     
    public String toString()
    {
        if(data==null || data.length==0)
            return null;
        StringBuilder sb=new StringBuilder();
        for(int i=0;i<data.length;i++)
            sb.append(data[i]+" ");
        sb.append(" d:"+this.distance);
        return sb.toString();
    }
}

   參考文獻:

    [1]李航.統計學習方法

 

  

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM