python kd树 搜索 代码


  kd树就是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构,可以运用在k近邻法中,实现快速k近邻搜索。构造kd树相当于不断地用垂直于坐标轴的超平面将k维空间切分,依次选择坐标轴对空间进行切分,选择训练实例点在选定坐标轴上的中位数为切分点。具体kd树的原理可以参考kd树的原理。

  代码是参考《统计学习方法》k近邻 kd树的python实现得到

  首先创建一个类,用于表示树的节点,包括:该节点的值,用于划分左右子树的切分轴,左子树,右子树

class decisionnode: def __init__(self,value=None,col=None,rb=None,lb=None): self.value=value self.col=col self.rb=rb self.lb=lb

  切分点为坐标轴上的中值,下面代码求得一个序列的中值

def median(x): n=len(x) x=list(x) x_order=sorted(x) return x_order[n//2],x.index(x_order[n//2])

然后就可以构造一颗kd树,左子树小于切分点,右子树大于切分点,data是输入的数据

def buildtree(x,j=0): rb=[] lb=[] m,n=x.shape if m==0: return None edge,row=median(x[:,j].copy()) for i in range(m): if x[i][j]>edge: rb.append(i) if x[i][j]<edge: lb.append(i) rb_x=x[rb,:] lb_x=x[lb,:] rightBranch=buildtree(rb_x,(j+1)%n) leftBranch=buildtree(lb_x,(j+1)%n) return decisionnode(x[row,:],j,rightBranch,leftBranch)

  接下来是树的搜索过程,可以用下图表示树的搜索过程,具体过程可以参考kd树的原理。

  

  代码如下:

#搜索树:nearestPoint,nearestValue均为全局变量
def traveltree(node,point): global nearestPoint,nearestValue if node==None: return 
    print(node.value) print('---') col=node.col if point[col]>node.value[col]: traveltree(node.rb,point) if point[col]<node.value[col]: traveltree(node.lb,point) dis=dist(node.value,point) print(dis) if dis<nearestValue: nearestPoint=node nearestValue=dis #print('nearestPoint,nearestValue' % (nearestPoint,nearestValue))
    if node.rb!=None or node.lb!=None: if abs(point[node.col] - node.value[node.col]) < nearestValue: if point[node.col]<node.value[node.col]: traveltree(node.rb,point) if point[node.col]>node.value[node.col]: traveltree(node.lb,point) def searchtree(tree,aim): global nearestPoint,nearestValue #nearestPoint=None
    nearestValue=float('inf') traveltree(tree,aim) return nearestPoint def dist(x1, x2): #欧式距离的计算 
    return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

 完整代码在此处取

 1 import numpy as np  2 from numpy import array  3 class decisionnode:  4     def __init__(self,value=None,col=None,rb=None,lb=None):  5         self.value=value  6         self.col=col  7         self.rb=rb  8         self.lb=lb  9         
10 #读取数据并将数据转换为矩阵形式 
11 def readdata(filename): 12     data=open(filename).readlines() 13     x=[] 14     for line in data: 15         line=line.strip().split('\t') 16         x_i=[] 17         for num in line: 18             num=float(num) 19  x_i.append(num) 20  x.append(x_i) 21     x=array(x) 22     return x 23 
24 #求序列的中值 
25 def median(x): 26     n=len(x) 27     x=list(x) 28     x_order=sorted(x) 29     return x_order[n//2],x.index(x_order[n//2]) 30 
31 #以j列的中值划分数据,左小右大,j=节点深度%列数 
32 def buildtree(x,j=0): 33     rb=[] 34     lb=[] 35     m,n=x.shape 36     if m==0: return None 37     edge,row=median(x[:,j].copy()) 38     for i in range(m): 39         if x[i][j]>edge: 40  rb.append(i) 41         if x[i][j]<edge: 42  lb.append(i) 43     rb_x=x[rb,:] 44     lb_x=x[lb,:] 45     rightBranch=buildtree(rb_x,(j+1)%n) 46     leftBranch=buildtree(lb_x,(j+1)%n) 47     return decisionnode(x[row,:],j,rightBranch,leftBranch) 48 
49 #搜索树:nearestPoint,nearestValue均为全局变量
50 def traveltree(node,point): 51     global nearestPoint,nearestValue 52     if node==None: return 
53     print(node.value) 54     print('---') 55     col=node.col 56     if point[col]>node.value[col]: 57  traveltree(node.rb,point) 58     if point[col]<node.value[col]: 59  traveltree(node.lb,point) 60     dis=dist(node.value,point) 61     print(dis) 62     if dis<nearestValue: 63         nearestPoint=node 64         nearestValue=dis 65         #print('nearestPoint,nearestValue' % (nearestPoint,nearestValue))
66     if node.rb!=None or node.lb!=None: 67         if abs(point[node.col] - node.value[node.col]) < nearestValue: 68             if point[node.col]<node.value[node.col]: 69  traveltree(node.rb,point) 70             if point[node.col]>node.value[node.col]: 71  traveltree(node.lb,point) 72         
73 def searchtree(tree,aim): 74     global nearestPoint,nearestValue 75     #nearestPoint=None
76     nearestValue=float('inf') 77  traveltree(tree,aim) 78     return nearestPoint 79         
80     
81 def dist(x1, x2): #欧式距离的计算 
82     return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5  
kdtree

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM