https://blog.csdn.net/App_12062011/article/details/51986805
一:kd樹構建
以二維平面點((x,y))的集合(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)為例結合下圖來說明k-d tree的構建過程。
(一)構建步驟
1.構建根節點時,此時的切分維度為(x),如上點集合在(x)維從小到大排序為:
(2,3),(4,7),(5,4),(7,2),(8,1),(9,6);
其中中位數為7,選擇中值(7,2)。(注:2,4,5,7,8,9在數學中的中值為(5 + 7)/2=6,但因該算法的中值需在點集合之內,所以本文中值計算用的是len(points)//2=3, points[3]=(7,2))
2.(2,3),(4,7),(5,4)掛在(7,2)節點的左子樹,(8,1),(9,6)掛在(7,2)節點的右子樹。
3.構建(7,2)節點的左子樹時,點集合(2,3),(4,7),(5,4)此時的切分維度為(y),從3,4,7選取中位數4,中值為(5,4)作為分割平面,(2,3)掛在其左子樹,(4,7)掛在其右子樹。
4.構建(7,2)節點的右子樹時,點集合(8,1),(9,6)此時的切分維度也為(y),中值為(9,6)作為分割平面,(8,1)掛在其左子樹。至此k-d tree構建完成。
上述的構建過程結合下圖可以看出,構建一個k-d tree即是將一個二維平面逐步划分的過程。
(二)代碼實現構建kd樹
class Node: def __init__(self,data,sp=0,left=None,right=None): self.data = data self.sp = sp #0是按特征1排序,1是按特征2排序 self.left = left self.right = right def __lt__(self, other): return self.data < other.data
class KDTree: def __init__(self,data): self.dim = data.shape[1] self.root = self.createTree(data,0) self.nearest_node = None self.nearest_dist = np.inf #設置無窮大 def createTree(self,dataset,sp): if len(dataset) == 0: return None dataset_sorted = dataset[np.argsort(dataset[:,sp])] #按特征列進行排序 #獲取中位數索引 mid = len(dataset) // 2 #生成節點 left = self.createTree(dataset_sorted[:mid],(sp+1)%self.dim) right = self.createTree(dataset_sorted[mid+1:],(sp+1)%self.dim) parentNode = Node(dataset_sorted[mid],sp,left,right) return parentNode
data = np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]) kdtree = KDTree(data) #創建KDTree
二:kd樹搜索(找最近鄰節點)
注意:最近鄰---當k為1時,稱為最近鄰。
在k-d樹中進行數據的查找也是特征匹配的重要環節,其目的是檢索在k-d樹中與查詢點距離最近的數據點。
(一)簡單案例一:查詢的點(2.1,3.1)
1.通過二叉搜索,從根節點順着搜索路徑很快就能找到最鄰近的近似點,也就是葉子節點(2,3)。
2.而找到的葉子節點並不一定就是最鄰近的,最鄰近肯定距離查詢點更近,應該位於以查詢點為圓心且通過葉子節點的圓域內。
3.為了找到真正的最近鄰,還需要進行'回溯'操作:
算法沿搜索路徑反向查找是否有距離查詢點更近的數據點。
推導:
1.此例中先從(7,2)點開始進行二叉查找,然后到達(5,4),最后到達(2,3),此時搜索路徑中的節點為<(7,2),(5,4),(2,3)>。
2.首先以(2,3)作為當前最近鄰點,計算其到查詢點(2.1,3.1)的距離為0.1414,
3.然后回溯到其父節點(5,4),並判斷在該父節點的其他子節點空間中是否有距離查詢點更近的數據點。以(2.1,3.1)為圓心,以0.1414為半徑畫圓,如圖3所示。發現該圓並不和超平面y = 4交割,因此不用進入(5,4)節點右子空間中去搜索。
4.4、最后,再回溯到(7,2),以(2.1,3.1)為圓心,以0.1414為半徑的圓更不會與x = 7超平面交割,因此不用進入(7,2)右子空間進行查找。至此,搜索路徑中的節點已經全部回溯完,結束整個搜索,返回最近鄰點(2,3),最近距離為0.1414。
(二)案例二:查找點為(2,4.5)
1.同樣先進行二叉查找,先從(7,2)查找到(5,4)節點,在進行查找時是由y = 4為分割超平面的,由於查找點為y值為4.5,因此進入右子空間查找到(4,7),形成搜索路徑<(7,2),(5,4),(4,7)>
2.取(4,7)為當前最近鄰點,計算其與目標查找點的距離為3.202。
3.然后回溯到(5,4),計算其與查找點之間的距離為3.041。((4,7)與目標查找點的距離為3.202,而(5,4)與查找點之間的距離為3.041,所以(5,4)為查詢點的最近點;)
4.以(2,4.5)為圓心,以3.041為半徑作圓,如圖4所示。可見該圓和y = 4超平面交割,所以需要進入(5,4)左子空間進行查找。此時需將(2,3)節點加入搜索路徑中得<(7,2),(2,3)>。
5.回溯至(2,3)葉子節點,(2,3)距離(2,4.5)比(5,4)要近,所以最近鄰點更新為(2,3),最近距離更新為1.5。
6.回溯至(7,2),以(2,4.5)為圓心1.5為半徑作圓,並不和x = 7分割超平面交割。
至此,搜索路徑回溯完。返回最近鄰點(2,3),最近距離1.5。
(三)代碼實現
import numpy as np class Node: def __init__(self,data,sp=0,left=None,right=None): self.data = data self.sp = sp #0是按特征1排序,1是按特征2排序 self.left = left self.right = right def __lt__(self, other): return self.data < other.data
class KDTree: def __init__(self,data): self.dim = data.shape[1] self.root = self.createTree(data,0) self.nearest_node = None self.nearest_dist = np.inf #設置無窮大 def createTree(self,dataset,sp): if len(dataset) == 0: return None dataset_sorted = dataset[np.argsort(dataset[:,sp])] #按特征列進行排序 #獲取中位數索引 mid = len(dataset) // 2 #生成節點 left = self.createTree(dataset_sorted[:mid],(sp+1)%self.dim) right = self.createTree(dataset_sorted[mid+1:],(sp+1)%self.dim) parentNode = Node(dataset_sorted[mid],sp,left,right) return parentNode def nearest(self, x): def visit(node): if node != None: dis = node.data[node.sp] - x[node.sp] #訪問子節點 visit(node.left if dis > 0 else node.right) #查看當前節點到目標節點的距離 二范數求距離 curr_dis = np.linalg.norm(x-node.data,2) #更新節點 if curr_dis < self.nearest_dist: self.nearest_dist = curr_dis self.nearest_node = node #比較目標節點到當前節點距離是否超過當前超平面,超過了就需要到另一個子樹中 if self.nearest_dist > abs(dis): #要到另一面查找 所以判斷條件與上面相反 visit(node.left if dis < 0 else node.right) #從根節點開始查找 node = self.root visit(node) return self.nearest_node.data,self.nearest_dist
data = np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]) kdtree = KDTree(data) #創建KDTree node,dist = kdtree.nearest(np.array([6,5])) print(node,dist)
(四)性能對比
https://www.cnblogs.com/21207-ihome/p/6084670.html
一般來講,最臨近搜索只需要檢測幾個葉子結點即可,如下圖所示:
但是,如果當實例點的分布比較糟糕時,幾乎要遍歷所有的結點,如下所示:
三:K-近鄰算法中kd樹搜索最近K個節點
補充:python內部沒有實現大頂堆,應該如何處理?
將原來得x值,變為-x即可
(一)算法思路(借助堆排序---heapq)
我們借助大小為k得大頂堆來實現我們K-近鄰算法:
1.首先,從根節點向下查找到葉節點
2.從葉節點開始回溯,記錄每一個距離目標點的距離到最大堆中。
(1)如果堆的大小<k,則正常回溯,並且如果到了根節點,我們也要去訪問另一側子樹
(2)如果堆的大小=k,我們每一次回溯時取出最大值,查看目標點是否與當前節點的另一側相交,然后決定是否去訪問另一側。當獲取的新的節點距離目標節點更小,則將當前最大距離出堆,將當前值插入,重新排序。直到我們找到的k個元素中的最大值,不再與當前節點另一邊相交即可。
(二)代碼實現
import numpy as np import heapq class Node: def __init__(self,data,sp=0,left=None,right=None): self.data = data self.sp = sp #0是按特征1排序,1是按特征2排序 self.left = left self.right = right self.nearest_dist = -np.inf #我們需要使用最小堆來模擬最大堆,我們設置默認大小-∞,實際就是+∞ def __lt__(self, other): return self.nearest_dist < other.nearest_dist class KDTree: def __init__(self,data): self.k = data.shape[1] self.root = self.createTree(data,0) self.heap = [] #初始化一個堆 def createTree(self,dataset,sp): if len(dataset) == 0: return None dataset_sorted = dataset[np.argsort(dataset[:,sp])] #按特征列進行排序 #獲取中位數索引 mid = len(dataset) // 2 #生成節點 left = self.createTree(dataset_sorted[:mid],(sp+1)%self.k) right = self.createTree(dataset_sorted[mid+1:],(sp+1)%self.k) parentNode = Node(dataset_sorted[mid],sp,left,right) return parentNode def nearest(self, x, k): def visit(node): if node != None: dis = node.data[node.sp] - x[node.sp] #訪問子節點 visit(node.left if dis > 0 else node.right) #查看當前節點到目標節點的距離 二范數求距離 curr_dis = np.linalg.norm(x-node.data,2) node.nearest_dist = -curr_dis #更新節點 if len(self.heap) < k: #直接加入 heapq.heappush(self.heap,node) else: #先獲取最大堆最大值,比較后決定 if nsmallest(1,self.heap)[0].nearest_dist < -curr_dis: heapq.heapreplace(self.heap, node) #比較目標節點到當前節點距離是否超過當前超平面,超過了就需要到另一個子樹中 if len(self.heap) < k or abs(nsmallest(1,self.heap)[0].nearest_dist) > abs(dis): #要到另一面查找 所以判斷條件與上面相反 visit(node.left if dis < 0 else node.right) #從根節點開始查找 node = self.root visit(node) nds = nlargest(k,self.heap) for i in range(k): nd = nds[i] print(nd.data,nd.nearest_dist)
data = np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]) kdtree = KDTree(data) #創建KDTree kdtree.nearest(np.array([6,5]),5)
(三)對比原始KNN
import numpy as np import matplotlib.pyplot as plt import pandas as pd def KNNClassfy(preData,dataSet,k): distance = np.sum(np.power(dataSet - preData,2),1) #注意:這里我們不進行開方,可以少算一次 sortDistIdx = np.argsort(distance,0)[:k] #小到大排序,獲取索引 for i in range(k): print(dataSet[sortDistIdx[i]],np.linalg.norm(dataSet[sortDistIdx[i]]-preData,2)) data = np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]) predata = np.array([6,5]) KNNClassfy(predata,data,5)