python kd樹 搜索 代碼


  kd樹就是一種對k維空間中的實例點進行存儲以便對其進行快速檢索的樹形數據結構,可以運用在k近鄰法中,實現快速k近鄰搜索。構造kd樹相當於不斷地用垂直於坐標軸的超平面將k維空間切分,依次選擇坐標軸對空間進行切分,選擇訓練實例點在選定坐標軸上的中位數為切分點。具體kd樹的原理可以參考kd樹的原理。

  代碼是參考《統計學習方法》k近鄰 kd樹的python實現得到

  首先創建一個類,用於表示樹的節點,包括:該節點的值,用於划分左右子樹的切分軸,左子樹,右子樹

class decisionnode: def __init__(self,value=None,col=None,rb=None,lb=None): self.value=value self.col=col self.rb=rb self.lb=lb

  切分點為坐標軸上的中值,下面代碼求得一個序列的中值

def median(x): n=len(x) x=list(x) x_order=sorted(x) return x_order[n//2],x.index(x_order[n//2])

然后就可以構造一顆kd樹,左子樹小於切分點,右子樹大於切分點,data是輸入的數據

def buildtree(x,j=0): rb=[] lb=[] m,n=x.shape if m==0: return None edge,row=median(x[:,j].copy()) for i in range(m): if x[i][j]>edge: rb.append(i) if x[i][j]<edge: lb.append(i) rb_x=x[rb,:] lb_x=x[lb,:] rightBranch=buildtree(rb_x,(j+1)%n) leftBranch=buildtree(lb_x,(j+1)%n) return decisionnode(x[row,:],j,rightBranch,leftBranch)

  接下來是樹的搜索過程,可以用下圖表示樹的搜索過程,具體過程可以參考kd樹的原理。

  

  代碼如下:

#搜索樹:nearestPoint,nearestValue均為全局變量
def traveltree(node,point): global nearestPoint,nearestValue if node==None: return 
    print(node.value) print('---') col=node.col if point[col]>node.value[col]: traveltree(node.rb,point) if point[col]<node.value[col]: traveltree(node.lb,point) dis=dist(node.value,point) print(dis) if dis<nearestValue: nearestPoint=node nearestValue=dis #print('nearestPoint,nearestValue' % (nearestPoint,nearestValue))
    if node.rb!=None or node.lb!=None: if abs(point[node.col] - node.value[node.col]) < nearestValue: if point[node.col]<node.value[node.col]: traveltree(node.rb,point) if point[node.col]>node.value[node.col]: traveltree(node.lb,point) def searchtree(tree,aim): global nearestPoint,nearestValue #nearestPoint=None
    nearestValue=float('inf') traveltree(tree,aim) return nearestPoint def dist(x1, x2): #歐式距離的計算 
    return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

 完整代碼在此處取

 1 import numpy as np  2 from numpy import array  3 class decisionnode:  4     def __init__(self,value=None,col=None,rb=None,lb=None):  5         self.value=value  6         self.col=col  7         self.rb=rb  8         self.lb=lb  9         
10 #讀取數據並將數據轉換為矩陣形式 
11 def readdata(filename): 12     data=open(filename).readlines() 13     x=[] 14     for line in data: 15         line=line.strip().split('\t') 16         x_i=[] 17         for num in line: 18             num=float(num) 19  x_i.append(num) 20  x.append(x_i) 21     x=array(x) 22     return x 23 
24 #求序列的中值 
25 def median(x): 26     n=len(x) 27     x=list(x) 28     x_order=sorted(x) 29     return x_order[n//2],x.index(x_order[n//2]) 30 
31 #以j列的中值划分數據,左小右大,j=節點深度%列數 
32 def buildtree(x,j=0): 33     rb=[] 34     lb=[] 35     m,n=x.shape 36     if m==0: return None 37     edge,row=median(x[:,j].copy()) 38     for i in range(m): 39         if x[i][j]>edge: 40  rb.append(i) 41         if x[i][j]<edge: 42  lb.append(i) 43     rb_x=x[rb,:] 44     lb_x=x[lb,:] 45     rightBranch=buildtree(rb_x,(j+1)%n) 46     leftBranch=buildtree(lb_x,(j+1)%n) 47     return decisionnode(x[row,:],j,rightBranch,leftBranch) 48 
49 #搜索樹:nearestPoint,nearestValue均為全局變量
50 def traveltree(node,point): 51     global nearestPoint,nearestValue 52     if node==None: return 
53     print(node.value) 54     print('---') 55     col=node.col 56     if point[col]>node.value[col]: 57  traveltree(node.rb,point) 58     if point[col]<node.value[col]: 59  traveltree(node.lb,point) 60     dis=dist(node.value,point) 61     print(dis) 62     if dis<nearestValue: 63         nearestPoint=node 64         nearestValue=dis 65         #print('nearestPoint,nearestValue' % (nearestPoint,nearestValue))
66     if node.rb!=None or node.lb!=None: 67         if abs(point[node.col] - node.value[node.col]) < nearestValue: 68             if point[node.col]<node.value[node.col]: 69  traveltree(node.rb,point) 70             if point[node.col]>node.value[node.col]: 71  traveltree(node.lb,point) 72         
73 def searchtree(tree,aim): 74     global nearestPoint,nearestValue 75     #nearestPoint=None
76     nearestValue=float('inf') 77  traveltree(tree,aim) 78     return nearestPoint 79         
80     
81 def dist(x1, x2): #歐式距離的計算 
82     return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5  
kdtree

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM