KD-tree講解


  KD-tree  講解
 by simb351 

  應神犇junble19768的要求,來水一發KD-tree講解。學習過程中發現關於KD-tree的資源實在是少,在OI中的應用更是少之又少。
雖然這東西很水,隨便嘴一嘴就能口胡出來。所以寫一篇講解。

 前置芝士:
1. 基礎BST姿勢
2. 替罪羊樹
3. 優良的空間感 比如能腦補出k維空間QWQ

KD-tree的應用:
  KD-tree主要解決對K維數據的管理,比如多維偏序。但是本弱發現目前OI中KD-tree主流考法為 維護二維平面中區間的信息。比如
ION 9102 彈跳 但是好像能被菜雞simba口胡的暴力算幾卡過。


KD-tree的原理:
  考慮從一般BST中類比過來---對於其中一個節點其左邊節點的值恆小於它本身,右邊反之。實際上是把所有節點的值從中間分開。
如果說BST是對一個一維線段的分割,那么KD-tree就是對K維空間分割。最終在小的空間內統計答案。說人話就是對K維按順序均勻分割
查哪部分就去哪個塊中查找。因為是按序分割,所以找到答案空間的時間是nlogn至nsqrtn我信你個鬼,不帶O2天天被卡。 為了划分空間
,KD-tree在第i層維護第i%k維的信息,即這一維中比它小的在左子樹,大的在右子樹。對於查詢就像BST一樣就好了。同BST,考慮
KD-tree如何保持自身平衡。由於用方差過於優雅,此處選擇替罪羊樹一樣的思路---拍扁重建。這樣KD-tree就愉快的講完了,撒花。
KD-tree代碼實現:
首先是樹的結點。
  
 
struct point
{
    int x[DIM]; //DIM☞維度 x表示一個k維向量 
    bool operator < (const point X) const
    {
        return x[now]<X.x[now];
    }
    /* 考慮分割一個維度時,為了讓分割更均勻,要盡量選最中間的點 now表示當前維護維度,定義小於號來維護中間的點。*/
}// 存儲一個向量 
struct node
{
    int l,r;//左右子樹
    int sze;//子樹大小
    int minn[DIM];//此節點維護的空間中第i維的最小值
    int maxx[DIM];//此節點維護的空間中第i維的最大值
    point data;//這個點所維護的向量
}//KD-tree上一個節點的定義

 


  
 
  
然后是維護一個節點的信息。
 
void update(int pos)
{
    for(int i=0;i<DEM;i++)
    {
        tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i];
        if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]);
        if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]);
        if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]);
        if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]);
    }
    tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1;
}/*就是字面意思,維護子樹大小,維護其子樹中能到達某一維度的最大值,最小值。*/

 


 
然后是把一個子樹拍扁。
void update(int pos)
{
    for(int i=0;i<DEM;i++)
    {
        tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i];
        if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]);
        if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]);
        if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]);
        if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]);
    }
    tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1;
}/*就是字面意思,維護子樹大小,維護其子樹中能到達某一維度的最大值,最小值。*/

 

接着是將一個序列加到樹上,就是把樹拍扁后再掛到樹上。
 
void update(int pos)
{
    for(int i=0;i<DEM;i++)
    {
        tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i];
        if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]);
        if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]);
        if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]);
        if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]);
    }
    tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1;
}/*就是字面意思,維護子樹大小,維護其子樹中能到達某一維度的最大值,最小值。*/

 

查看這顆樹要不要拍扁重建。
void check(int& pos,int dim)
{
    if(tree[pos].sze*alpha<tree[tree[pos].l].sze||tree[pos].sze*alpha<tree[tree[pos].r].sze)
    {rev(pos,0); pos=build(1,tree[pos].sze,dim);}
}// 字面意思 不平衡就拍扁重建 , alpha是替罪羊的平衡因子

 


插入單個點。
 
int insert(int pos,point data,int dim)
{
    if(!pos) {pos=New(); tree[pos].l=tree[pos].r=0; tree[pos].data=data; update(pos); return pos;}
    if(data.x[dim]<=tree[pos].data.x[dem]) tree[pos].l=insert(tree[pos].l,data,dim^1);
    else tree[pos].r=insert(tree[pos].r,data,dim^1);
    update(pos); check(pos,dim); return pos;
}//像平衡樹一樣左右看看加那邊,然后掛上節點。最后check一下不讓樹退化

 

查詢,就查詢經典問題,給定n個點坐標,以及一個點坐標s 詢問n個點中那個離s最近的點是哪個。
 
void query(int pos,point data)
{
    ans=min(ans,dist(data,tree[pos].data)); //用當前點信息更新ans
    int dist_left=INF; int dist_right=INF;
    if(tree[pos].l) dist_left=get_dist(data,tree[pos].l);
    if(tree[pos].r) dist_right=get_dist(data,tree[pos].r);
    // L,R維護是查詢點s到當前點左右子樹所維護空間的距離
    if(dist_left<dist_right)
    {
        if(dist_left<ans) query(tree[pos].l,data);
        if(dist_right<ans) query(tree[pos].r,data);
    }
    else 
    {
        if(dist_right<ans) query(tree[pos].r,data);
        if(dist_left<ans) query(tree[pos].l,data);
    }
    // 以當前點為圓心,以ans為半徑畫圓,如果達不到左/右子樹所維護的空間,就不查那邊。
}

 


沒了,真沒了,寫寫題就好了。
附上bzoj2648代碼,就是上面的問題
 
#include<bits/stdc++.h>
#define DEM 2
#define alpha (1.130/2)
#define maxn 1000010
#define INF 0x3f3f3f3f
using namespace std;
int n,m;
int u,v;
int now;
int ans;
int opt;
int root;
int points;
queue<int>Q;
struct point
{
    int x[DEM]; 
    bool operator < (const point X) const {return x[now]<X.x[now];}
}one[maxn];
struct node
{
    int l,r;
    int sze;
    int minn[DEM];
    int maxx[DEM];
    point data;
}tree[maxn];
int dist(point,point);
int get_dist(point,int);
int New();
void update(int);
void rev(int,int);
int build(int,int,int);
void check(int&,int);
int insert(int,point,int);
void query(int,point);
int main()
{
    cin>>n>>m;
    for(int i=1;i<=n;i++) cin>>one[i].x[0]>>one[i].x[1];
    root=build(1,n,0);
    for(int i=1;i<=m;i++)
    {
        cin>>opt>>u>>v;
        if(opt==1) {root=insert(root,(point){u,v},0);}
        else {ans=INF; query(root,(point){u,v}); cout<<ans<<endl;}  
    } 
}
int dist(point A,point B) {return abs(A.x[0]-B.x[0])+abs(A.x[1]-B.x[1]);}
int get_dist(point A,int pos) {int ret=0; for(int i=0;i<DEM;i++) ret+=max(0,A.x[i]-tree[pos].maxx[i])+max(0,tree[pos].minn[i]-A.x[i]); return ret;}
int New()
{
    if(!Q.empty()) {static int tmp; tmp=Q.front(); Q.pop(); return tmp;}
    else return ++points;
}
void update(int pos)
{
    for(int i=0;i<DEM;i++)
    {
        tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i];
        if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]);
        if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]);
        if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]);
        if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]);
    }
    tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1;
}
void rev(int pos,int num)
{
    if(tree[pos].l) rev(tree[pos].l,num);
    one[tree[tree[pos].l].sze+num+1]=tree[pos].data; Q.push(pos);
    if(tree[pos].r) rev(tree[pos].r,tree[tree[pos].l].sze+num+1);
}
int build(int l,int r,int dem)
{
    if(l>r) return 0;
    int mid=(l+r)>>1,pos=New();
    now=dem; nth_element(one+l,one+mid,one+r+1); tree[pos].data=one[mid]; 
    tree[pos].l=build(l,mid-1,dem^1); tree[pos].r=build(mid+1,r,dem^1);
    update(pos); return pos;
}
void check(int& pos,int dem)
{
    if(tree[pos].sze*alpha<tree[tree[pos].l].sze||tree[pos].sze*alpha<tree[tree[pos].r].sze)
    {rev(pos,0); pos=build(1,tree[pos].sze,dem);}
}
int insert(int pos,point data,int dem)
{
    if(!pos) {pos=New(); tree[pos].l=tree[pos].r=0; tree[pos].data=data; update(pos); return pos;}
    if(data.x[dem]<=tree[pos].data.x[dem]) tree[pos].l=insert(tree[pos].l,data,dem^1);
    else tree[pos].r=insert(tree[pos].r,data,dem^1);
    update(pos); check(pos,dem); return pos;
}
void query(int pos,point data)
{
    ans=min(ans,dist(data,tree[pos].data));
    int dist_left=INF; int dist_right=INF;
    if(tree[pos].l) dist_left=get_dist(data,tree[pos].l);
    if(tree[pos].r) dist_right=get_dist(data,tree[pos].r);
    if(dist_left<dist_right)
    {
        if(dist_left<ans) query(tree[pos].l,data);
        if(dist_right<ans) query(tree[pos].r,data);
    }
    else 
    {
        if(dist_right<ans) query(tree[pos].r,data);
        if(dist_left<ans) query(tree[pos].l,data);
    }
}

 


 
 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM