K近鄰法
1基本概念
K近鄰法,是一種基本分類和回歸規則。根據已有的訓練數據集(含有標簽),對於新的實例,根據其最近的k個近鄰的類別,通過多數表決的方式進行預測。
2模型相關
2.1 距離的度量方式
定義距離
(1)歐式距離:p=2。
(2)曼哈頓距離:p=1。
(3)各坐標的最大值:p=∞。
2.2 K值的選擇
通常使用交叉驗證法來選取最優的k值。
k值大小的影響:
k越小,只有距該點較近的實例才會起作用,學習的近似誤差會較小。但此時又會對這些近鄰的實例很敏感,如果緊鄰點存在噪聲,預測就會出錯,即學習的估計誤差大,泛化能力不好。
K越大,距該點較遠的實例也會起作用,造成近似誤差增大,使預測發生錯誤。
2.3 k近鄰法的實現:kd樹
Kd樹是二叉樹。kd樹是一種對K維空間中的實例點進行存儲以便對其進行快速檢索的樹形數據結構.
Kd樹是二叉樹, 表示對K維空間的一個划分( partition).構造Kd樹相 當於不斷地用垂直於坐標軸的超平面將k維空間切分, 構成一系列的k維超矩形區域.Kd樹的每個結點對應於一個k維超矩形區域
其中,創建kd樹時,垂直於坐標軸的超平面垂直的坐標軸選擇是:
L=(J mod k)+1。其中,j為當前節點的節點深度,k為k維空間(給定實例點的k個維度)。根節點的節點深度為0.此公式可看為:依次循環實例點的k個維所對應的坐標軸。
Kd樹的節點(分割點)為L維上所有實例點的中位數。
2.4 Kd樹的實現
別處代碼實現基於其他博客,但是糾正了其中的錯誤,能夠返回前k個近鄰。如果要求最近鄰,只需要將k=1即可。
1 public class BinaryTreeOrder { 2 3 public void preOrder(Node root) { 4 if(root!= null){ 5 System.out.print(root.toString()); 6 preOrder(root.left); 7 preOrder(root.right); 8 } 9 } 10 }
public class kd_main { public static void main(String[] args) { List<Node> nodeList=new ArrayList<Node>(); nodeList.add(new Node(new double[]{5,4})); nodeList.add(new Node(new double[]{9,6})); nodeList.add(new Node(new double[]{8,1})); nodeList.add(new Node(new double[]{7,2})); nodeList.add(new Node(new double[]{2,3})); nodeList.add(new Node(new double[]{4,7})); nodeList.add(new Node(new double[]{4,3})); nodeList.add(new Node(new double[]{1,3})); kd_main kdTree=new kd_main(); Node root=kdTree.buildKDTree(nodeList,0); new BinaryTreeOrder().preOrder(root); for (Node node : nodeList) { System.out.println(node.toString()+"-->"+node.left.toString()+"-->"+node.right.toString()); } System.out.println(root); System.out.println(kdTree.searchKNN(root,new Node(new double[]{2.1,3.1}),2)); System.out.println(kdTree.searchKNN(root,new Node(new double[]{2,4.5}),1)); System.out.println(kdTree.searchKNN(root,new Node(new double[]{2,4.5}),3)); System.out.println(kdTree.searchKNN(root,new Node(new double[]{6,1}),5)); } /** * 構建kd樹 返回根節點 * @param nodeList * @param index * @return */ public Node buildKDTree(List<Node> nodeList,int index) { if(nodeList==null || nodeList.size()==0) return null; quickSortForMedian(nodeList,index,0,nodeList.size()-1);//中位數排序 Node root=nodeList.get(nodeList.size()/2);//中位數 當做根節點 root.dim=index; List<Node> leftNodeList=new ArrayList<Node>();//放入左側區域的節點 包括包含與中位數等值的節點-_- List<Node> rightNodeList=new ArrayList<Node>(); for(Node node:nodeList) { if(root!=node) { if(node.getData(index)<=root.getData(index)) leftNodeList.add(node);//左子區域 包含與中位數等值的節點 else rightNodeList.add(node); } } //計算從哪一維度切分 int newIndex=index+1;//進入下一個維度 if(newIndex>=root.data.length) newIndex=0;//從0維度開始再算 root.left=buildKDTree(leftNodeList,newIndex);//添加左右子區域 root.right=buildKDTree(rightNodeList,newIndex); if(root.left!=null) root.left.parent=root;//添加父指針 if(root.right!=null) root.right.parent=root;//添加父指針 return root; } /** * 查詢最近鄰 * @param root kd樹 * @param q 查詢點 * @param k * @return */ public List<Node> searchKNN(Node root,Node q,int k) { List<Node> knnList=new ArrayList<Node>(); searchBrother(knnList,root,q,k); return knnList; } /** * searhchBrother * @param knnList * @param k * @param q */ public void searchBrother(List<Node> knnList, Node root, Node q, int k) { // Node almostNNode=root;//近似最近點 Node leafNNode=searchLeaf(root,q); double curD=q.computeDistance(leafNNode);//最近近似點與查詢點的距離 也就是球體的半徑 leafNNode.distance=curD; maintainMaxHeap(knnList,leafNNode,k); System.out.println("leaf1"+leafNNode.getData(leafNNode.parent.dim)); while(leafNNode!=root) { if (getBrother(leafNNode)!=null) { Node brother=getBrother(leafNNode); System.out.println("brother1"+brother.getData(brother.parent.dim)); if(curD>Math.abs(q.getData(leafNNode.parent.dim)-leafNNode.parent.getData(leafNNode.parent.dim))||knnList.size()<k) { //這樣可能在另一個子區域中存在更加近似的點 searchBrother(knnList,brother, q, k); } } System.out.println("leaf2"+leafNNode.getData(leafNNode.parent.dim)); leafNNode=leafNNode.parent;//返回上一級 double rootD=q.computeDistance(leafNNode);//最近近似點與查詢點的距離 也就是球體的半徑 leafNNode.distance=rootD; maintainMaxHeap(knnList,leafNNode,k); } } /** * 獲取兄弟節點 * @param node * @return */ public Node getBrother(Node node) { if(node==node.parent.left) return node.parent.right; else return node.parent.left; } /** * 查詢到葉子節點 * @param root * @param q * @return */ public Node searchLeaf(Node root,Node q) { Node leaf=root,next=null; int index=0; while(leaf.left!=null || leaf.right!=null) { if(q.getData(index)<leaf.getData(index)) { next=leaf.left;//進入左側 }else if(q.getData(index)>leaf.getData(index)) { next=leaf.right; }else{ //當取到中位數時 判斷左右子區域哪個更加近 if(q.computeDistance(leaf.left)<q.computeDistance(leaf.right)) next=leaf.left; else next=leaf.right; } if(next==null) break;//下一個節點是空時 結束了 else{ leaf=next; if(++index>=root.data.length) index=0; } } return leaf; } /** * 維護一個k的最大堆 * @param listNode * @param newNode * @param k */ public void maintainMaxHeap(List<Node> listNode,Node newNode,int k) { if(listNode.size()<k) { maxHeapFixUp(listNode,newNode);//不足k個堆 直接向上修復 }else if(newNode.distance<listNode.get(0).distance){ //比堆頂的要小 還需要向下修復 覆蓋堆頂 maxHeapFixDown(listNode,newNode); } } /** * 從上往下修復 將會覆蓋第一個節點 * @param listNode * @param newNode */ private void maxHeapFixDown(List<Node> listNode,Node newNode) { listNode.set(0, newNode); int i=0; int j=i*2+1; while(j<listNode.size()) { if(j+1<listNode.size() && listNode.get(j).distance<listNode.get(j+1).distance) j++;//選出子結點中較大的點,第一個條件是要滿足右子樹不為空 if(listNode.get(i).distance>=listNode.get(j).distance) break; Node t=listNode.get(i); listNode.set(i, listNode.get(j)); listNode.set(j, t); i=j; j=i*2+1; } } private void maxHeapFixUp(List<Node> listNode,Node newNode) { listNode.add(newNode); int j=listNode.size()-1; int i=(j+1)/2-1;//i是j的parent節點 while(i>=0) { if(listNode.get(i).distance>=listNode.get(j).distance) break; Node t=listNode.get(i); listNode.set(i, listNode.get(j)); listNode.set(j, t); j=i; i=(j+1)/2-1; } } /** * 使用快排進進行一個中位數的查找 完了之后返回的數組size/2即中位數 * @param nodeList * @param index * @param left * @param right */ @Test private void quickSortForMedian(List<Node> nodeList,int index,int left,int right) { if(left>=right || nodeList.size()<=0) return ; Node kn=nodeList.get(left); double k=kn.getData(index);//取得向量指定索引的值 int i=left,j=right; //控制每一次遍歷的結束條件,i與j相遇 while(i<j) { //從右向左找一個小於i處值的值,並填入i的位置 while(nodeList.get(j).getData(index)>=k && i<j) j--; nodeList.set(i, nodeList.get(j)); //從左向右找一個大於i處值的值,並填入j的位置 while(nodeList.get(i).getData(index)<=k && i<j) i++; nodeList.set(j, nodeList.get(i)); } nodeList.set(i, kn); if(i==nodeList.size()/2) return ;//完成中位數的排序了,但並不是完成了所有數的排序,這個終止條件只是保證中位數是正確的。去掉該條件,可以保證在遞歸的作用下,將所有的樹 //將所有的數進行排序 else if(i<nodeList.size()/2) { quickSortForMedian(nodeList,index,i+1,right);//只需要排序右邊就可以了 }else{ quickSortForMedian(nodeList,index,left,i-1);//只需要排序左邊就可以了 } // for (Node node : nodeList) { // System.out.println(node.getData(index)); // } } }
public class Node implements Comparable<Node>{ public double[] data;//樹上節點的數據 是一個多維的向量 public double distance;//與當前查詢點的距離 初始化的時候是沒有的 public Node left,right,parent;//左右子節點 以及父節點 public int dim=-1;//維度 建立樹的時候判斷的維度 public Node(double[] data) { this.data=data; } /** * 返回指定索引上的數值 * @param index * @return */ public double getData(int index) { if(data==null || data.length<=index) return Integer.MIN_VALUE; return data[index]; } @Override public int compareTo(Node o) { if(this.distance>o.distance) return 1; else if(this.distance==o.distance) return 0; else return -1; } /** * 計算距離 這里返回歐式距離 * @param that * @return */ public double computeDistance(Node that) { if(this.data==null || that.data==null || this.data.length!=that.data.length) return Double.MAX_VALUE;//出問題了 距離最遠 double d=0; for(int i=0;i<this.data.length;i++) { d+=Math.pow(this.data[i]-that.data[i], 2); } return Math.sqrt(d); } public String toString() { if(data==null || data.length==0) return null; StringBuilder sb=new StringBuilder(); for(int i=0;i<data.length;i++) sb.append(data[i]+" "); sb.append(" d:"+this.distance); return sb.toString(); } }
參考文獻:
[1]李航.統計學習方法