之前兩篇隨筆介紹了kd樹的原理,並用python實現了kd樹的構建和搜索,具體可以參考
kd樹常與knn算法聯系在一起,knn算法通常要搜索k近鄰,而不僅僅是最近鄰,下面的代碼將利用kd樹搜索目標點的k個近鄰。
首先還是創建一個類,用於保存結點的值,左右子樹,以及用於划分左右子樹的切分軸
class decisionnode: def __init__(self,value=None,col=None,rb=None,lb=None): self.value=value self.col=col self.rb=rb self.lb=lb
切分點為坐標軸上的中值,下面代碼求得一個序列的中值
def median(x): n=len(x) x=list(x) x_order=sorted(x) return x_order[n//2],x.index(x_order[n//2])
然后按照左子樹大於切分點,右子樹小於切分點的規則構造kd樹,其中data是輸入的數據
#以j列的中值划分數據,左小右大,j=節點深度%列數
def buildtree(x,j=0): rb=[] lb=[] m,n=x.shape if m==0: return None edge,row=median(x[:,j].copy()) for i in range(m): if x[i][j]>edge: rb.append(i) if x[i][j]<edge: lb.append(i) rb_x=x[rb,:] lb_x=x[lb,:] rightBranch=buildtree(rb_x,(j+1)%n) leftBranch=buildtree(lb_x,(j+1)%n) return decisionnode(x[row,:],j,rightBranch,leftBranch)
接下來就是搜索樹得到k近鄰的過程,與搜索最近鄰的過程大致相同,需要創建一個字典knears,用於存儲k近鄰的點以及與目標點的距離(歐氏距離)
搜索的過程為:
(1)第一步還是遍歷樹,找到目標點所屬區域對應的葉節點
(2)從葉結點依次向上回退,按照尋找最近鄰點的方法回退到父節點,並判斷其另一個子節點對區域內是否可能存在k近鄰點,具體的,在每個結點上進行以下操作:
(a)如果字典中的成員個數不足k個,將該結點加入字典
(b)如果字典中的成員不少於k個,判斷該結點與目標結點之間的距離是否不大於字典中各結點所對應距離的的最大值,如果不大於,便將其加入到字典中
(c)對於父節點來說,如果目標點與其切分軸之間的距離不大於字典中各結點所對應距離的的最大值,便需要訪問該父節點的另一個子節點
(3)每當字典中增加新成員,就按距離值對字典進行降序排序,將得到的列表賦值給poinelist,pointlist[0][1]便是字典中各結點所對應距離的最大值
(4)當回退到根節點並完成對其操作時,pointlist中后k個結點就是目標點的k近鄰
代碼如下:
#搜索樹:輸出目標點的近鄰點
def traveltree(node,aim): global pointlist #存儲排序后的k近鄰點和對應距離
if node==None: return col=node.col if aim[col]>node.value[col]: traveltree(node.rb,aim) if aim[col]<node.value[col]: traveltree(node.lb,aim) dis=dist(node.value,aim) if len(knears)<k: knears.setdefault(tuple(node.value.tolist()),dis)#列表不能作為字典的鍵
pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True) elif dis<=pointlist[0][1]: knears.setdefault(tuple(node.value.tolist()),dis) pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True) if node.rb!=None or node.lb!=None: if abs(aim[node.col] - node.value[node.col]) < pointlist[0][1]: if aim[node.col]<node.value[node.col]: traveltree(node.rb,aim) if aim[node.col]>node.value[node.col]: traveltree(node.lb,aim) return pointlist
完整代碼在此處取

1 import numpy as np 2 from numpy import array 3 class decisionnode: 4 def __init__(self,value=None,col=None,rb=None,lb=None): 5 self.value=value 6 self.col=col 7 self.rb=rb 8 self.lb=lb 9
10 #讀取數據並將數據轉換為矩陣形式
11 def readdata(filename): 12 data=open(filename).readlines() 13 x=[] 14 for line in data: 15 line=line.strip().split('\t') 16 x_i=[] 17 for num in line: 18 num=float(num) 19 x_i.append(num) 20 x.append(x_i) 21 x=array(x) 22 return x 23
24 #求序列的中值
25 def median(x): 26 n=len(x) 27 x=list(x) 28 x_order=sorted(x) 29 return x_order[n//2],x.index(x_order[n//2]) 30
31 #以j列的中值划分數據,左小右大,j=節點深度%列數
32 def buildtree(x,j=0): 33 rb=[] 34 lb=[] 35 m,n=x.shape 36 if m==0: return None 37 edge,row=median(x[:,j].copy()) 38 for i in range(m): 39 if x[i][j]>edge: 40 rb.append(i) 41 if x[i][j]<edge: 42 lb.append(i) 43 rb_x=x[rb,:] 44 lb_x=x[lb,:] 45 rightBranch=buildtree(rb_x,(j+1)%n) 46 leftBranch=buildtree(lb_x,(j+1)%n) 47 return decisionnode(x[row,:],j,rightBranch,leftBranch) 48
49 #搜索樹:輸出目標點的近鄰點
50 def traveltree(node,aim): 51 global pointlist #存儲排序后的k近鄰點和對應距離
52 if node==None: return
53 col=node.col 54 if aim[col]>node.value[col]: 55 traveltree(node.rb,aim) 56 if aim[col]<node.value[col]: 57 traveltree(node.lb,aim) 58 dis=dist(node.value,aim) 59 if len(knears)<k: 60 knears.setdefault(tuple(node.value.tolist()),dis)#列表不能作為字典的鍵
61 pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True) 62 elif dis<=pointlist[0][1]: 63 knears.setdefault(tuple(node.value.tolist()),dis) 64 pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True) 65 if node.rb!=None or node.lb!=None: 66 if abs(aim[node.col] - node.value[node.col]) < pointlist[0][1]: 67 if aim[node.col]<node.value[node.col]: 68 traveltree(node.rb,aim) 69 if aim[node.col]>node.value[node.col]: 70 traveltree(node.lb,aim) 71 return pointlist 72
73 def dist(x1, x2): #歐式距離的計算
74 return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5
75
76 knears={} 77 k=int(input('請輸入k的值')) 78 if k<2: print('k不能是1') 79 global pointlist 80 pointlist=[] 81 file=input('請輸入數據文件地址') 82 data=readdata(file) 83 tree=buildtree(data) 84 tmp=input('請輸入目標點') 85 tmp=tmp.split(',') 86 aim=[] 87 for num in tmp: 88 num=float(num) 89 aim.append(num) 90 aim=tuple(aim) 91 pointlist=traveltree(tree,aim) 92 for point in pointlist[-k:]: 93 print(point)