k-d tree代碼解析


  上一篇較詳細地介紹了k-d樹算法。本文來講解具體的實現代碼。

  首先是一些數據結構的定義。我們先來定義單個數據,代碼如下:

//單個數據向量結構定義
struct _Examplar
{
public:

_Examplar():dom_dims(0){} //數據維度初始化為0
  //帶有完整的兩個參數的constructor,這里const是為了保護原數據不被修改
_Examplar(const std::vector<double> elt, int dims)
{
if(dims > 0)
{
dom_elt = elt;
dom_dims = dims;
}
else
{
dom_dims = 0;
}
}
(一些重載的構造函數和運算符,元素的訪問控制函數等)
        _Examplar(int dims)    //只含有維度信息的constructor
{
if(dims > 0)
{
dom_elt.resize(dims);
dom_dims = dims;
}
else
{
dom_dims = 0;
}
}
_Examplar(const _Examplar& rhs) //copy-constructor
{
if(rhs.dom_dims > 0)
{
dom_elt = rhs.dom_elt;
dom_dims = rhs.dom_dims;
}
else
{
dom_dims = 0;
}
}
_Examplar& operator=(const _Examplar& rhs) //重載"="運算符
{
if(this == &rhs)
return *this;

releaseExamplarMem();

if(rhs.dom_dims > 0)
{
dom_elt = rhs.dom_elt;
dom_dims = rhs.dom_dims;
}

return *this;
}
~_Examplar()
{
}
double& dataAt(int dim) //定義訪問控制函數
{
assert(dim < dom_dims);
return dom_elt[dim];
}
double& operator[](int dim) //重載"[]"運算符,實現下標訪問
{
return dataAt(dim);
}
const double& dataAt(int dim) const //定義只讀訪問函數
{
assert(dim < dom_dims);
return dom_elt[dim];
}
const double& operator[](int dim) const //重載"[]"運算符,實現下標只讀訪問
{
return dataAt(dim);
}
void create(int dims) //創建數據向量
{
releaseExamplarMem();
if(dims > 0)
{
dom_elt.resize(dims); //控制數據向量維度
dom_dims = dims;
}
}
int getDomDims() const //獲得數據向量維度信息
{
return dom_dims;
}
void setTo(double val) //數據向量初始化設置
{
if(dom_dims > 0)
{
for(int i=0;i<dom_dims;i++)
{
dom_elt[i] = val;
}
}
}
private:
void releaseExamplarMem() //清除現有數據向量
{
dom_elt.clear();
dom_dims = 0;
}
private:
std::vector<double> dom_elt; //每個數據定義為一個double類型的向量
int dom_dims; //數據向量的維度
};

  結構_Examplar定義了單個數據節點的結構,主要包含的信息有:1.數據向量本身;2.數據向量的維度。接下來定義一整個數據集的結構,代碼如下:

//數據集結構定義
class ExamplarSet : public TrainData //整個數據集類,由一個抽象類TrainData派生
{
private:
//_Examplar *_ex_set;
std::vector<_Examplar> _ex_set; //定義含有若干個_Examplar類數據向量的數據集
int _size; //數據集大小
int _dims; //數據集中每個數據向量的維度
public:
(一些重載的構造函數運算符,元素訪問控制函數等)
        ExamplarSet():_size(0), _dims(0){}
ExamplarSet(std::vector<_Examplar> ex_set, int size, int dims);
ExamplarSet(int size, int dims);
ExamplarSet(const ExamplarSet& rhs);
ExamplarSet& operator=(const ExamplarSet& rhs);
~ExamplarSet(){}

_Examplar& examplarAt(int idx)
{
assert(idx < _size);
return _ex_set[idx];
}
_Examplar& operator[](int idx)
{
return examplarAt(idx);
}
const _Examplar& examplarAt(int idx) const
{
assert(idx < _size);
return _ex_set[idx];
}
void create(int size, int dims);
int getDims() const { return _dims;}
int getSize() const { return _size;}
_HyperRectangle calculateRange();
bool empty() const
{
return (_size == 0);
}
    void sortByDim(int dim);     //按某個方向維的排序函數
bool remove(int idx); //去除數據集中排序后指定位置的數據向量
void push_back(const _Examplar& ex) //添加某個數據向量至數據集末尾
{
_ex_set.push_back(ex);
_size++;
}
int readData(char *strFilePath); //從文件讀取數據集
private:
void releaseExamplarSetMem() //清除現有數據集
{
_ex_set.clear();
_size = 0;
}
};

  類ExamplarSet定義了整個數據集的結構,其包含的主要信息有:1.含有若干個_Examplar類數據向量的數據集;2.數據集的大小;3.每個數據向量的維度。以上兩個結構是整個算法兩個基本的數據結構,這里的代碼只是展示其主要包含的結構信息,詳細的定義及函數實現代碼請參看附件。

  接下來就要定義k-d tree的結構。同樣采用上述由點定義到集定義的思路,我們先來定義k-d tree中一個節點結構,代碼如下:

//k-d tree節點結構定義
class KDTreeNode
{
private:
int _split_dim; //該節點的最大區分度方向維
_Examplar _dom_elt; //該節點的數據向量
_HyperRectangle _range_hr; //表示數據范圍的超矩形結構
public:
KDTreeNode *_left_child, *_right_child, *_parent; //該節點的左右子樹和父節點
(一些重載的構造函數,元素訪問控制函數等)
public:
KDTreeNode():_left_child(0), _right_child(0), _parent(0),
_split_dim(0){}
KDTreeNode(KDTreeNode *left_child, KDTreeNode *right_child,
KDTreeNode *parent, int split_dim, _Examplar dom_elt, _HyperRectangle range_hr):
_left_child(left_child), _right_child(right_child), _parent(parent),
_split_dim(split_dim), _dom_elt(dom_elt), _range_hr(range_hr){}
KDTreeNode(const KDTreeNode &rhs);
KDTreeNode& operator=(const KDTreeNode &rhs);
_Examplar& getDomElt() { return _dom_elt; }
_HyperRectangle& getHyperRectangle(){ return _range_hr; }
int& splitDim(){ return _split_dim; }
void create(KDTreeNode *left_child, KDTreeNode *right_child,
KDTreeNode *parent, int split_dim, _Examplar dom_elt, _HyperRectangle range_hr);
};

  類KDTreeNode就是按照前一篇表1所述定義的。需要注意的是_HyperRectangle這一結構,它表示的就是這一節點所代表的空間范圍Range,其定義如下:

struct _HyperRectangle    //定義表示數據范圍的超矩形結構
{
_Examplar min; //統計數據集中所有數據向量每個維度上最小值組成的一個數據向量
_Examplar max; //統計數據集中所有數據向量每個維度上最大值組成的一個數據向量
(一些重載的構造函數)
        _HyperRectangle() {}
_HyperRectangle(_Examplar mx, _Examplar mn)
{
assert (mx.getDomDims() == mn.getDomDims());
min = mn;
max = mx;
}
_HyperRectangle(const _HyperRectangle& rhs)
{
min = rhs.min;
max = rhs.max;
}
_HyperRectangle& operator= (const _HyperRectangle& rhs)
{
if(this == &rhs)
return *this;
min = rhs.min;
max = rhs.max;
return *this;
}
void create(_Examplar mx, _Examplar mn)
{
assert (mx.getDomDims() == mn.getDomDims());
min = mn;
max = mx;
}
};

  對於整個數據集來說_HyperRectangle表示的就是對全體的統計范圍信息,對部分數據集來說其表示的就是對部分數據的統計范圍信息。還是以上篇中實例中的數據{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}為例,_HyperRectangle表示的統計范圍如圖1所示:

圖1  _HyperRectangle表示的統計范圍

  • 對於根節點(7,2),其所對應的空間范圍是整個數據集,所以根節點(7,2)的_range_hr就是對整個數據集所有維度方向(此例即x,y方向)的數據范圍統計得min = {dom_elt = (2,1),dom_dims = 2},max = {dom_elt = (9,7),dom_dims = 2};
  • 對於中間節點(5,4),其所對應的空間范圍是根節點的左子樹,所以節點(5,4)的_range_hr就是對整個數據集所有維度方向(此例即x,y方向)的數據范圍統計得min = {dom_elt = (2,3),dom_dims = 2},max = {dom_elt = (5,7),dom_dims = 2};
  • 對於葉子節點(4,7),其所對應的空間范圍是節點本身,所以節點(4,7)的_range_hr就是對整個數據集所有維度方向(此例即x,y方向)的 數據范圍統計得min = {dom_elt = (4,7),dom_dims = 2},max = {dom_elt = (4,7),dom_dims = 2};

  最后再進行整個k-d tree結構的定義。代碼如下:

class KDTree    //k-d tree結構定義
{
public:
KDTreeNode *_root; //k-d tree的根節點
public:
KDTree():_root(NULL){}
void create(const ExamplarSet &exm_set); //創建k-d tree,實際上調用createKDTree
void destroy(); //銷毀k-d tree,實際上調用destroyKDTree
~KDTree(){ destroyKDTree(_root); }
std::pair<_Examplar, double> findNearest(_Examplar target); //查找最近鄰點函數,返回值是pair類型
//實際是調用findNearest_i
//查找距離在range范圍內的近鄰點,返回這樣近鄰點的個數,實際是調用findNearest_range
int findNearest(_Examplar target, double range, std::vector<std::pair<_Examplar, double>> &res_nearest);
private:
KDTreeNode* createKDTree(const ExamplarSet &exm_set);
void destroyKDTree(KDTreeNode *root);
std::pair<_Examplar, double> findNearest_i(KDTreeNode *root, _Examplar target);
int findNearest_range(KDTreeNode *root, _Examplar target, double range,
std::vector<std::pair<_Examplar, double>> &res_nearest);

  可見,整個k-d tree結構是由一系列KDTreeNode類的節點構成。整個k-d樹的構建算法和基於k-d樹的最鄰近查找算法主要就是由createKDTree,findNearest_i以及findNearest_range這三個函數完成。代碼分別如下:

  • createKDTree
//KDTree::是由於定義了KDTree的namespace
KDTree::KDTreeNode* KDTree::KDTree::createKDTree( const ExamplarSet &exm_set )
{
if(exm_set.empty())
return NULL;

ExamplarSet exm_set_copy(exm_set);

int dims = exm_set_copy.getDims();
int size = exm_set_copy.getSize();

//計算每個維的方差,選出方差值最大的維
double var_max = -0.1;
double avg, var;
int dim_max_var = -1;
for(int i=0;i<dims;i++)
{
avg = 0;
var = 0;
//求某一維的總和
for(int j=0;j<size;j++)
{
avg += exm_set_copy[j][i];
}
//求平均
avg /= size;
//求方差
for(int j=0;j<size;j++)
{
var += ( exm_set_copy[j][i] - avg ) *
( exm_set_copy[j][i] - avg );
}
var /= size;
if(var > var_max)
{
var_max = var;
dim_max_var = i;
}
}

//確定節點的數據矢量
_HyperRectangle hr = exm_set_copy.calculateRange(); //統計節點空間范圍
exm_set_copy.sortByDim(dim_max_var); //將所有數據向量按最大區分度方向排序
int mid = size / 2;
_Examplar exm_split = exm_set_copy.examplarAt(mid); //取出排序結果的中間節點
exm_set_copy.remove(mid); //將中間節點作為父(根)節點,所有將其從數據集中去除

//確定左右節點
ExamplarSet exm_set_left(0, exm_set_copy.getDims());
ExamplarSet exm_set_right(0, exm_set_copy.getDims());
exm_set_right.remove(0);

int size_new = exm_set_copy.getSize(); //獲得子數據空間大小
for(int i=0;i<size_new;i++) //生成左右子節點
{
_Examplar temp = exm_set_copy[i];
if( temp.dataAt(dim_max_var) <
exm_split.dataAt(dim_max_var) )
exm_set_left.push_back(temp);
else
exm_set_right.push_back(temp);
}

KDTreeNode *pNewNode = new KDTreeNode(0, 0, 0, dim_max_var, exm_split, hr);
pNewNode->_left_child = createKDTree(exm_set_left); //遞歸調用生成左子樹
if(pNewNode->_left_child != NULL) //確認左子樹父節點
pNewNode->_left_child->_parent = pNewNode;
pNewNode->_right_child = createKDTree(exm_set_right); //遞歸調用生成右子樹
if(pNewNode->_right_child != NULL) //確認右子樹父節點
pNewNode->_right_child->_parent = pNewNode;

return pNewNode; //最終返回k-d tree的根節點
}

  整個createKDTree函數完全符合上篇中表2所述。注意其中統計節點空間范圍calculateRange這一函數,其定義如下:

KDTree::_HyperRectangle KDTree::ExamplarSet::calculateRange()
{
assert(_size > 0);
assert(_dims > 0);
_Examplar mn(_dims);
_Examplar mx(_dims);

for(int j=0;j<_dims;j++)
{
mn.dataAt(j) = (*this)[0][j]; //初始化最小范圍向量
mx.dataAt(j) = (*this)[0][j]; //初始化最大范圍向量
}

for(int i=1;i<_size;i++) //統計數據集中每一個數據向量
{
for(int j=0;j<_dims;j++)
{
if( (*this)[i][j] < mn[j] ) //比較每一維,尋找最小值
mn[j] = (*this)[i][j];
if( (*this)[i][j] > mx[j] ) //比較每一維,尋找最大值
mx[j] = (*this)[i][j];
}
}
_HyperRectangle hr(mx, mn);

return hr; //返回一個_HyperRectangle結構
}
  • findNearest_i
std::pair<KDTree::_Examplar, double> KDTree::KDTree::findNearest_i( KDTreeNode *root, _Examplar target )
{
KDTreeNode *pSearch = root;

//堆棧用於保存搜索路徑
std::vector<KDTreeNode*> search_path;

_Examplar nearest;

double max_dist;

while(pSearch != NULL) //首先通過二叉查找得到搜索路徑
{
search_path.push_back(pSearch);
int s = pSearch->splitDim();
if(target[s] <= pSearch->getDomElt()[s])
{
pSearch = pSearch->_left_child;
}
else
{
pSearch = pSearch->_right_child;
}
}

nearest = search_path.back()->getDomElt(); //取路徑中最后的葉子節點為回溯前的最鄰近點
max_dist = Distance_exm(nearest, target);

search_path.pop_back();

//回溯搜索路徑
while(!search_path.empty())
{
KDTreeNode *pBack = search_path.back();
search_path.pop_back();

if( pBack->_left_child == NULL && pBack->_right_child == NULL) //如果是葉子節點,就直接比較距離的大小
{
if( Distance_exm(nearest, target) > Distance_exm(pBack->getDomElt(), target) )
{
nearest = pBack->getDomElt();
max_dist = Distance_exm(pBack->getDomElt(), target);
}
}
else
{
int s = pBack->splitDim();
if( abs(pBack->getDomElt()[s] - target[s]) < max_dist) //以target為圓心,max_dist為半徑的圓和分割面如果
{ //有交割,則需要進入另一邊子空間搜索
if( Distance_exm(nearest, target) > Distance_exm(pBack->getDomElt(), target) )
{
nearest = pBack->getDomElt();
max_dist = Distance_exm(pBack->getDomElt(), target);
}
if(target[s] <= pBack->getDomElt()[s]) //如果target位於左子空間,就應進入右子空間
pSearch = pBack->_right_child;
else
pSearch = pBack->_left_child; //如果target位於右子空間,就應進入左子空間
if(pSearch != NULL)
search_path.push_back(pSearch); //將新的節點加入search_path中
}
}
}

std::pair<_Examplar, double> res(nearest, max_dist);

return res; //返回包含最鄰近點和最近距離的pair
}
  • findNearest_range
int KDTree::KDTree::findNearest_range( KDTreeNode *root, _Examplar target, double range, 
std::vector<std::pair<_Examplar, double>> &res_nearest )
{
if(root == NULL)
return 0;
double dist_sq, dx;
int ret, added_res = 0;
dist_sq = 0;
dist_sq = Distance_exm(root->getDomElt(), target); //計算搜索路徑中每個節點和target的距離

if(dist_sq <= range) {                   //將范圍內的近鄰添加到結果向量res_nearest中
std::pair<_Examplar,double> temp(root->getDomElt(), dist_sq);
res_nearest.push_back(temp);
//結果個數+1
added_res = 1;
}

dx = target[root->splitDim()] - root->getDomElt()[root->splitDim()];
//左子樹或右子樹遞歸的查找
ret = findNearest_range(dx <= 0.0 ? root->_left_child : root->_right_child, target, range, res_nearest);
//當另外一邊可能存在范圍內的近鄰
if(ret >= 0 && fabs(dx) < range) {
added_res += ret;
ret = findNearest_range(dx <= 0.0 ? root->_right_child : root->_left_child, target, range, res_nearest);
}

added_res += ret;
return added_res; //最終返回范圍內的近鄰個數
}

  依然利用前述實例的數據來做測試,查找(2.1,3.1)和(2,4.5)兩點的最近鄰,並查找距離在4以內的所有近鄰。程序運行結果如下:
                     

           圖2  查找(2.1,3.1)的結果                                                       圖3  查找(2,4.5)的結果

 

附件:http://files.cnblogs.com/eyeszjwang/kdtree.rar

轉載請注明:http://www.cnblogs.com/eyeszjwang/articles/2432465.html


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM