KD-tree 講解
by simb351
應神犇junble19768的要求,來水一發KD-tree講解。學習過程中發現關於KD-tree的資源實在是少,在OI中的應用更是少之又少。
雖然這東西很水,隨便嘴一嘴就能口胡出來。所以寫一篇講解。
前置芝士:
1. 基礎BST姿勢
2. 替罪羊樹
3. 優良的空間感 比如能腦補出k維空間QWQ
KD-tree的應用:
KD-tree主要解決對K維數據的管理,比如多維偏序。但是本弱發現目前OI中KD-tree主流考法為
維護二維平面中區間的信息。比如
ION 9102 彈跳 但是好像能被菜雞simba口胡的暴力算幾卡過。
KD-tree的原理:
考慮從一般BST中類比過來---對於其中一個節點其左邊節點的值恆小於它本身,右邊反之。實際上是把所有節點的值從中間分開。
如果說BST是對一個一維線段的分割,那么KD-tree就是對K維空間分割。最終在小的空間內統計答案。說人話就是對K維按順序均勻分割
查哪部分就去哪個塊中查找。因為是按序分割,所以找到答案空間的時間是nlogn至nsqrtn我信你個鬼,不帶O2天天被卡。
為了划分空間
,KD-tree在第i層維護第i%k維的信息,即這一維中比它小的在左子樹,大的在右子樹。對於查詢就像BST一樣就好了。同BST,考慮
KD-tree如何保持自身平衡。由於用方差過於優雅,此處選擇替罪羊樹一樣的思路---拍扁重建。這樣KD-tree就愉快的講完了,撒花。
KD-tree代碼實現:
首先是樹的結點。
struct point { int x[DIM]; //DIM☞維度 x表示一個k維向量 bool operator < (const point X) const { return x[now]<X.x[now]; } /* 考慮分割一個維度時,為了讓分割更均勻,要盡量選最中間的點 now表示當前維護維度,定義小於號來維護中間的點。*/ }// 存儲一個向量 struct node { int l,r;//左右子樹 int sze;//子樹大小 int minn[DIM];//此節點維護的空間中第i維的最小值 int maxx[DIM];//此節點維護的空間中第i維的最大值 point data;//這個點所維護的向量 }//KD-tree上一個節點的定義
然后是維護一個節點的信息。
void update(int pos) { for(int i=0;i<DEM;i++) { tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i]; if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]); if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]); if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]); if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]); } tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1; }/*就是字面意思,維護子樹大小,維護其子樹中能到達某一維度的最大值,最小值。*/
然后是把一個子樹拍扁。
void update(int pos) { for(int i=0;i<DEM;i++) { tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i]; if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]); if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]); if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]); if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]); } tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1; }/*就是字面意思,維護子樹大小,維護其子樹中能到達某一維度的最大值,最小值。*/
接着是將一個序列加到樹上,就是把樹拍扁后再掛到樹上。
void update(int pos) { for(int i=0;i<DEM;i++) { tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i]; if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]); if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]); if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]); if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]); } tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1; }/*就是字面意思,維護子樹大小,維護其子樹中能到達某一維度的最大值,最小值。*/
查看這顆樹要不要拍扁重建。
void check(int& pos,int dim) { if(tree[pos].sze*alpha<tree[tree[pos].l].sze||tree[pos].sze*alpha<tree[tree[pos].r].sze) {rev(pos,0); pos=build(1,tree[pos].sze,dim);} }// 字面意思 不平衡就拍扁重建 , alpha是替罪羊的平衡因子
插入單個點。
int insert(int pos,point data,int dim) { if(!pos) {pos=New(); tree[pos].l=tree[pos].r=0; tree[pos].data=data; update(pos); return pos;} if(data.x[dim]<=tree[pos].data.x[dem]) tree[pos].l=insert(tree[pos].l,data,dim^1); else tree[pos].r=insert(tree[pos].r,data,dim^1); update(pos); check(pos,dim); return pos; }//像平衡樹一樣左右看看加那邊,然后掛上節點。最后check一下不讓樹退化
查詢,就查詢經典問題,給定n個點坐標,以及一個點坐標s 詢問n個點中那個離s最近的點是哪個。
void query(int pos,point data) { ans=min(ans,dist(data,tree[pos].data)); //用當前點信息更新ans int dist_left=INF; int dist_right=INF; if(tree[pos].l) dist_left=get_dist(data,tree[pos].l); if(tree[pos].r) dist_right=get_dist(data,tree[pos].r); // L,R維護是查詢點s到當前點左右子樹所維護空間的距離 if(dist_left<dist_right) { if(dist_left<ans) query(tree[pos].l,data); if(dist_right<ans) query(tree[pos].r,data); } else { if(dist_right<ans) query(tree[pos].r,data); if(dist_left<ans) query(tree[pos].l,data); } // 以當前點為圓心,以ans為半徑畫圓,如果達不到左/右子樹所維護的空間,就不查那邊。 }
沒了,真沒了,寫寫題就好了。
附上bzoj2648代碼,就是上面的問題
#include<bits/stdc++.h> #define DEM 2 #define alpha (1.130/2) #define maxn 1000010 #define INF 0x3f3f3f3f using namespace std; int n,m; int u,v; int now; int ans; int opt; int root; int points; queue<int>Q; struct point { int x[DEM]; bool operator < (const point X) const {return x[now]<X.x[now];} }one[maxn]; struct node { int l,r; int sze; int minn[DEM]; int maxx[DEM]; point data; }tree[maxn]; int dist(point,point); int get_dist(point,int); int New(); void update(int); void rev(int,int); int build(int,int,int); void check(int&,int); int insert(int,point,int); void query(int,point); int main() { cin>>n>>m; for(int i=1;i<=n;i++) cin>>one[i].x[0]>>one[i].x[1]; root=build(1,n,0); for(int i=1;i<=m;i++) { cin>>opt>>u>>v; if(opt==1) {root=insert(root,(point){u,v},0);} else {ans=INF; query(root,(point){u,v}); cout<<ans<<endl;} } } int dist(point A,point B) {return abs(A.x[0]-B.x[0])+abs(A.x[1]-B.x[1]);} int get_dist(point A,int pos) {int ret=0; for(int i=0;i<DEM;i++) ret+=max(0,A.x[i]-tree[pos].maxx[i])+max(0,tree[pos].minn[i]-A.x[i]); return ret;} int New() { if(!Q.empty()) {static int tmp; tmp=Q.front(); Q.pop(); return tmp;} else return ++points; } void update(int pos) { for(int i=0;i<DEM;i++) { tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i]; if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]); if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]); if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]); if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]); } tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1; } void rev(int pos,int num) { if(tree[pos].l) rev(tree[pos].l,num); one[tree[tree[pos].l].sze+num+1]=tree[pos].data; Q.push(pos); if(tree[pos].r) rev(tree[pos].r,tree[tree[pos].l].sze+num+1); } int build(int l,int r,int dem) { if(l>r) return 0; int mid=(l+r)>>1,pos=New(); now=dem; nth_element(one+l,one+mid,one+r+1); tree[pos].data=one[mid]; tree[pos].l=build(l,mid-1,dem^1); tree[pos].r=build(mid+1,r,dem^1); update(pos); return pos; } void check(int& pos,int dem) { if(tree[pos].sze*alpha<tree[tree[pos].l].sze||tree[pos].sze*alpha<tree[tree[pos].r].sze) {rev(pos,0); pos=build(1,tree[pos].sze,dem);} } int insert(int pos,point data,int dem) { if(!pos) {pos=New(); tree[pos].l=tree[pos].r=0; tree[pos].data=data; update(pos); return pos;} if(data.x[dem]<=tree[pos].data.x[dem]) tree[pos].l=insert(tree[pos].l,data,dem^1); else tree[pos].r=insert(tree[pos].r,data,dem^1); update(pos); check(pos,dem); return pos; } void query(int pos,point data) { ans=min(ans,dist(data,tree[pos].data)); int dist_left=INF; int dist_right=INF; if(tree[pos].l) dist_left=get_dist(data,tree[pos].l); if(tree[pos].r) dist_right=get_dist(data,tree[pos].r); if(dist_left<dist_right) { if(dist_left<ans) query(tree[pos].l,data); if(dist_right<ans) query(tree[pos].r,data); } else { if(dist_right<ans) query(tree[pos].r,data); if(dist_left<ans) query(tree[pos].l,data); } }