KNN算法之KD樹(K-dimension Tree)實現 K近鄰查詢


KD樹是一種分割k維數據空間的數據結構,主要應用於多維空間關鍵數據的搜索,如范圍搜索和最近鄰搜索。

KD樹使用了分治的思想,對比二叉搜索樹(BST),KD樹解決的是多維空間內的最近點(K近點)問題。(思想與之前見過的最近點對問題很相似,將所有點分為兩邊,對於可能橫跨划分線的點對再進一步討論)

KD樹用來優化KNN算法中的查詢復雜度。

一、建樹

建立KDtree,主要有兩步操作:選擇合適的分割維度選擇中值節點作為分割節點

分割維度的選擇遵循的原則是,選擇范圍最大的緯度,也即是方差最大的緯度作為分割維度,為什么方差最大的適合作為特征呢?

因為方差大,數據相對“分散”,選取該特征來對數據集進行分割,數據散得更“開”一些。

分割節點的選擇原則是,將這一維度的數據進行排序,選擇正中間的節點作為分割節點,確保節點左邊的點的維度值小於節點的維度值,節點右邊的點的維度值大於節點的維度值。

這兩步步驟影響搜索效率,非常關鍵。

二、搜索K近點

需要的數據結構:最大堆(此處我對距離取負從而用最小堆實現的最大堆,因為python的heapq模塊只有最小堆)、堆棧(用列表實現)

a.利用二叉搜索找到葉子節點並將搜索的結點路徑壓入堆棧stack中

b.通過保存在堆棧中的搜索路徑回溯,直至堆棧中沒有結點了

對於b步驟,需要區分葉子結點和非葉結點:

1、葉子結點

葉子結點:計算與目標點的距離。若候選堆中不足K個點,將葉子結點加入候選堆中;如果K個點夠了,判斷是否比候選堆中距離最小的結點(因為距離取了相反數)還要大, 說明應當加入候選堆;

2、非葉結點

對於非葉結點,處理步驟和葉子結點差不多,只是需要額外考慮以目標點為圓心,最大堆的堆頂元素為半徑的超球體是否和划分當前空間的超平面相交,如果相交說明未訪問的另一邊的空間有可能包含比當前已有的K個近點更近的點,需要搜索另一邊的空間;此外,當候選堆中沒有K個點,那么不管有沒有相交,都應當搜索未訪問的另一邊空間,因為候選堆的點不夠K個。

步驟:計算與目標點的距離
1、若不足K個點,將結點加入候選堆中;
如果K個點夠了,判斷是否比候選堆中距離最小的結點(因為距離取了相反數)還要大。

2、判斷候選堆中的最小距離是否小於Xi離當前超平面的距離(即是否需要判斷未訪問的另一邊要不要搜索)當然如果不足K個點,雖然超平面不相交,依舊要搜索另一邊,直到找到葉子結點,並且把路徑加入回溯棧中。

三、預測

KNN通常用來分類或者回歸問題,筆者已經封裝好了兩種預測的方法。

python代碼實現:

 1 import heapq  2 class KDNode(object):  3     def __init__(self,feature=None,father=None,left=None,right=None,split=None):  4         self.feature=feature # dimension index (按第幾個維度的特征進行划分的)
 5         self.father=father #並沒有用到  6         self.left=left  7         self.right=right  8         self.split=split # X value and Y value (元組,包含特征X和真實值Y)
 9 
 10 class KDTree(object):  11     def __init__(self):  12         self.root=KDNode()  13         pass
 14 
 15     def _get_variance(self,X,row_indexes,feature_index):  16         # X (2D list): samples * dimension
 17         # row_indexes (1D list): choose which row can be calculated
 18         # feature_index (int): calculate which dimension
 19         n = len(row_indexes)  20         sum1 = 0  21         sum2 = 0  22         for id in row_indexes:  23             sum1 = sum1 + X[id][feature_index]  24             sum2 = sum2 + X[id][feature_index]**2
 25 
 26         return sum2/n - (sum1/n)**2
 27 
 28     def _get_max_variance_feature(self,X,row_indexes):  29         mx_var = -1
 30         dim_index = -1
 31         for dim in range(len(X[0])):  32             dim_var = self._get_variance(X,row_indexes,dim)  33             if dim_var>mx_var:  34                 mx_var=dim_var  35                 dim_index=dim  36         # return max variance feature index (int)
 37         return dim_index  38 
 39     def _get_median_index(self,X,row_indexes,feature_index):  40         median_index =  len(row_indexes)//2
 41         select_X = [(idx,X[idx][feature_index]) for idx in row_indexes]  42         sorted_X = select_X  43         sorted(sorted_X,key= lambda x:x[1])  44         #return median index in feature_index dimension (int)
 45         return sorted_X[median_index][0]  46 
 47     def _split_feature(self,X,row_indexes,feature_index,median_index):  48         left_ids = []  49         right_ids = []  50         median_val = X[median_index][feature_index]  51         for id in row_indexes:  52             if id==median_index:  53                 continue
 54             val = X[id][feature_index]  55             if val < median_val:  56  left_ids.append(id)  57             else:  58  right_ids.append(id)  59         # return (left points index and right points index)(list,list)
 60         # 把當前的樣本按feature維度進行划分為兩份
 61         return left_ids, right_ids  62 
 63     def build_tree(self,X,Y):  64         row_indexes =[i for i in range(len(X))]  65         node =self.root  66         queue = [(node,row_indexes)]  67         # BFS創建KD樹
 68         while queue:  69             root,ids = queue.pop(0)  70             if len(ids)==1:  71                 root.feature = 0 #如果是葉子結點,維度賦0
 72                 root.split = (X[ids[0]],Y[ids[0]])  73                 continue
 74             # 選取方差最大的特征維度划分,取樣本的中位數作為median
 75             feature_index = self._get_max_variance_feature(X,ids)  76             median_index = self._get_median_index(X,ids,feature_index)  77             left_ids,right_ids = self._split_feature(X,ids,feature_index,median_index)  78             root.feature = feature_index  79             root.split = (X[median_index],Y[median_index])  80             if left_ids:  81                 root.left = KDNode()  82                 root.left.father = root  83  queue.append((root.left,left_ids))  84             if right_ids:  85                 root.right = KDNode()  86                 root.right.father = root  87  queue.append((root.right,right_ids))  88 
 89     def _get_distance(self,Xi,node,p=2):  90         # p=2 default Euclidean distance
 91         nx = node.split[0]  92         dist = 0  93         for i in range(len(Xi)):  94             dist=dist + (abs(Xi[i]-nx[i])**p)  95         dist = dist**(1/p)  96         return dist  97 
 98     def _get_hyperplane_distance(self,Xi,node):  99         xx = node.split[0] 100         dist = abs(Xi[node.feature] - xx[node.feature]) 101         return dist 102 
103     def _is_leaf(self,node): 104         if node.left is None and node.right is None: 105             return True 106         else: 107             return False 108 
109     def get_nearest_neighbour(self,Xi,K=1): 110         search_paths = [] 111         max_heap = [] #use min heap achieve max heap (因為python只有最小堆)
112         priority_num = 0  # remove same distance
113         heapq.heappush(max_heap,(float('-inf'),priority_num,None)) 114         priority_num +=1
115         node = self.root 116         # 找到離Xi最近的葉子結點
117         while node is not None: 118  search_paths.append(node) 119             if Xi[node.feature] < node.split[0][node.feature]: 120                 node = node.left 121             else: 122                 node = node.right 123 
124         while search_paths: 125             now = search_paths.pop() 126             # 葉子結點:計算與Xi的距離,若不足K個點,將葉子結點加入候選堆中;
127             # 如果K個點夠了,判斷是否比候選堆中距離最小的結點(因為距離取了相反數)還要大,
128             # 說明應當加入候選堆;
129             if self._is_leaf(now): 130                 dist = self._get_distance(Xi,now) 131                 dist = -dist 132                 mini_dist = max_heap[0][0] 133                 if len(max_heap) < K : 134  heapq.heappush(max_heap,(dist,priority_num,now)) 135                     priority_num+=1
136                 elif dist > mini_dist: 137                     _ = heapq.heappop(max_heap) 138  heapq.heappush(max_heap,(dist,priority_num,now)) 139                     priority_num+=1
140             # 非葉結點:計算與Xi的距離
141             # 1、若不足K個點,將結點加入候選堆中;
142             # 如果K個點夠了,判斷是否比候選堆中距離最小的結點(因為距離取了相反數)還要大,
143             # 2、判斷候選堆中的最小距離是否小於Xi離當前超平面的距離(即是否需要判斷另一邊要不要搜索)
144             # 當然如果不足K個點,雖然超平面不相交,依舊要搜索另一邊,
145             # 直到找到葉子結點,並且把路徑加入回溯棧中
146             else : 147                 dist = self._get_distance(Xi, now) 148                 dist = -dist 149                 mini_dist = max_heap[0][0] 150                 if len(max_heap) < K: 151  heapq.heappush(max_heap, (dist, priority_num, now)) 152                     priority_num += 1
153                 elif dist > mini_dist: 154                     _ = heapq.heappop(max_heap) 155  heapq.heappush(max_heap, (dist, priority_num, now)) 156                     priority_num += 1
157 
158                 mini_dist = max_heap[0][0] 159                 if len(max_heap)<K or -(self._get_hyperplane_distance(Xi,now)) > mini_dist: 160                     # search another child tree
161                     if Xi[now.feature] >= now.split[0][now.feature]: 162                         child_node = now.left 163                     else: 164                         child_node = now.right 165                     # record path until find child leaf node
166                     while child_node is not None: 167  search_paths.append(child_node) 168                         if Xi[child_node.feature] < child_node.split[0][child_node.feature]: 169                             child_node = child_node.left 170                         else: 171                             child_node = child_node.right 172         return max_heap 173 
174     def predict_classification(self,Xi,K=1): 175         # 多分類問題預測
176         y =self.get_nearest_neighbour(Xi,K) 177         mp = {} 178         for i in y: 179             if i[2].split[1] in mp: 180                 mp[i[2].split[1]]+=1
181             else: 182                 mp[i[2].split[1]]=1
183         pre_y = -1
184         max_cnt =-1
185         for k,v in mp.items(): 186             if v>max_cnt: 187                 max_cnt=v 188                 pre_y=k 189         return pre_y 190 
191     def predict_regression(self,Xi,K=1): 192         #回歸問題預測
193         pre_y = self.get_nearest_neighbour(Xi,K) 194         return sum([i[2].split[1] for i in pre_y])/K 195 
196 
197 # t =KDTree()
198 # xx = [[3,3],[1,2],[5,6],[999,999],[5,5]]
199 # z = [1,0,1,1,1]
200 # t.build_tree(xx,z)
201 # y=t.predict_regression([4,5.5],K=5)
202 # print(y)

參考資料:

《統計學習方法》——李航著

https://blog.csdn.net/qq_32478489/article/details/82972391?utm_medium=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param

https://zhuanlan.zhihu.com/p/45346117

https://www.cnblogs.com/xingzhensun/p/9693362.html

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM