KDtree淺談
1.對KDtree的理解
首先要知道$KDtree$的用處,$KDtree$是用來進行多維數點的,一般這些點都是在在而二維及二維以上,因為一維上的問題,我們基本都可以運用線段樹來解決。我對$KDtree$的理解就是一個自帶剪枝的暴力,並且這個剪枝因為我們對這些多維上的點的較優秀的排列而顯得十分有用。
2.前置知識
在學習$KDtree$之前要先知道並會運用西面三個知識點:
1) 首先,要會建二叉搜索樹,因為整個$KDtree$就是一顆二叉搜索樹。
2) 還需要知道什么事估價函數,因為剪枝的時候要運用到估價函數。
3) 對空間的想象能力,因為$KDtree$是處理圖形上的問題,所以還需要有一定的空間想象能力。
3.KDTree的講解
因為$KDtree$是一種優美的暴力,並且我們要在上面剪枝,所以我們自然想讓每一次剪枝,剪下去盡可能大的部分,所以我們能想到每一次將區間等大的分割,既然要的等大的分割,又要是二叉搜索樹,我們就要讓中間值作為當前節點,所有比它小的都放在它的左面,比它大的都放在它的右面。
知道大致思路了,就要來定義什么是大小了,因為一個點是在多維里,所以和它有關的值有多個。最好想的就是按讀入的順序,進行排序,第一維作為第一關鍵字,第二維作為第二關鍵字,以此類推。我們根據這些點的維度將它們從小到大排序(下面已二維上的點為例),每一次取當前區間的中間值來建樹。這樣我們就能將整個圖分成下面的形式:
顯然這種分法分出的圖並不是最有利,因為每一點的管轄范圍都太小了。我們考慮另一種分割方式,我們將這些點的排序方式進行改變,我們將排序的關鍵字每一次向順時針進行轉動,即我們第一次排序的第一關鍵字是第一維,第二次是第二維……第$n$次是第$n\%維數+1$維。這樣上面的圖形就可以改變成為:
這樣我們在剪枝的時候就能剪去更多的節點。
知道了如何去排序,我們現在就要知道怎么來找中間值。在函數庫里面有一個函數$nth\_element$,這個就能實現我們要的功能。這個函數不知道實現的話,可以上網上找一下學習一下。我們在建樹的時候要維護出來幾個值,這幾個值的運用在下面會進行講解。這幾個值是$mn[0],mx[0],mn[1],mx[1]$,分別表示以當前節點為根的子樹第一維的最小值和最大值,第二維的最小值和最大值,這樣我們在建樹的時候應該更新。
struct Node {long long pla[2],mn[2],mx[2];int id,lson,rson;}node[N]; bool cmp(const Node &a,const Node &b) {return a.pla[sta]<b.pla[sta];} void up(int p,int k) { node[p].mn[0]=min(node[p].mn[0],node[k].mn[0]); node[p].mx[0]=max(node[p].mx[0],node[k].mx[0]); node[p].mn[1]=min(node[p].mn[1],node[k].mn[1]); node[p].mx[1]=max(node[p].mx[1],node[k].mx[1]); } int build(int l,int r,int now) { sta=now;int mid=(l+r)>>1; nth_element(node+l,node+mid,node+r+1,cmp); node[mid].mn[0]=node[mid].mx[0]=node[mid].pla[0]; node[mid].mn[1]=node[mid].mx[1]=node[mid].pla[1]; if(l!=mid) node[mid].lson=build(l,mid-1,(now+1)%2); if(r!=mid) node[mid].rson=build(mid+1,r,(now+1)%2); if(node[mid].lson) up(mid,node[mid].lson); if(node[mid].rson) up(mid,node[mid].rson); return mid; }
建樹之后,我們就可以在里面進行一些操作,比如找離定點的最遠點,最近點,維護矩形內信息等等,下面就是一些估價函數的代碼,以及矩形內區間賦值。
找離當前點的最近點的估價函數及查詢(歐幾里得距離):
long long dis(int p) {return squ(node[p].pla[0]-x)+squ(node[p].pla[1]-y);} long long getdis(int p) { long long tmp=0; tmp+=squ(max(abs(node[p].mx[0]-x),abs(node[p].mn[0]-x))); tmp+=squ(max(abs(node[p].mx[1]-y),abs(node[p].mn[1]-y))); return tmp; } void ask(int p) { long long tmp=dis(p);tmpx.dis=tmp,tmpx.id=node[p].id; if(q.top().dis<=tmpx.dis) q.push(tmpx),q.pop(); long long tmpl=(node[p].lson)?getdis(node[p].lson):-inf; long long tmpr=(node[p].rson)?getdis(node[p].rson):-inf; if(tmpl>tmpr) { if(tmpl>=q.top().dis&&node[p].lson) ask(node[p].lson); if(tmpr>=q.top().dis&&node[p].rson) ask(node[p].rson); } else { if(tmpr>=q.top().dis&&node[p].rson) ask(node[p].rson); if(tmpl>=q.top().dis&&node[p].lson) ask(node[p].lson); } }
找離當前點的最遠點的估價函數及查詢(曼哈頓距離):
int getdis_mx(int p) { int tmp=0; tmp+=max(abs(node[p].mx[0]-x),abs(node[p].mn[0]-x)); tmp+=max(abs(node[p].mx[1]-y),abs(node[p].mn[1]-y)); return tmp; } void ask_mx(int p) { int tmp=abs(node[p].pla[0]-x)+abs(node[p].pla[1]-y); if(tmp>lenth_mx) lenth_mx=tmp; int tmpl=(node[p].lson)?(getdis_mx(node[p].lson)):-inf; int tmpr=(node[p].rson)?(getdis_mx(node[p].rson)):-inf; if(tmpl>tmpr) { if(tmpl>lenth_mx) ask_mx(node[p].lson); if(tmpr>lenth_mx) ask_mx(node[p].rson); } else { if(tmpr>lenth_mx) ask_mx(node[p].rson); if(tmpl>lenth_mx) ask_mx(node[p].lson); } }
找離當前點的最遠點的估價函數及查詢(曼哈頓距離):
int getdis_mn(int p) { int tmp=0; if(x<node[p].mn[0]) tmp+=node[p].mn[0]-x; if(x>node[p].mx[0]) tmp+=x-node[p].mx[0]; if(y<node[p].mn[1]) tmp+=node[p].mn[1]-y; if(y>node[p].mx[1]) tmp+=y-node[p].mx[1]; return tmp; } void ask_mn(int p) { int tmp=abs(node[p].pla[0]-x)+abs(node[p].pla[1]-y); if(tmp&&tmp<lenth_mn) lenth_mn=tmp; int tmpl=(node[p].lson)?(getdis_mn(node[p].lson)):inf; int tmpr=(node[p].rson)?(getdis_mn(node[p].rson)):inf; if(tmpl<tmpr) { if(tmpl<lenth_mn) ask_mn(node[p].lson); if(tmpr<lenth_mn) ask_mn(node[p].rson); } else { if(tmpr<lenth_mn) ask_mn(node[p].rson); if(tmpl<lenth_mn) ask_mn(node[p].lson); } }
矩陣賦值,矩陣查找:
void pushdown(int p) { if(!node[p].tag) return; if(node[p].lson) node[node[p].lson].tag=node[node[p].lson].col=node[p].tag; if(node[p].rson) node[node[p].rson].tag=node[node[p].rson].col=node[p].tag; node[p].tag=0; } void change(int p,int w,int x,int y,int z,int col) { if(!p) return; if(node[p].mx[0]<w||node[p].mn[0]>x) return; if(node[p].mx[1]<y||node[p].mn[1]>z) return; pushdown(p); if(node[p].pla[0]>=w&&node[p].pla[0]<=x&& node[p].pla[1]>=y&&node[p].pla[1]<=z) node[p].col=col; if(node[p].mn[0]>=w&&node[p].mx[0]<=x&& node[p].mn[1]>=y&&node[p].mx[1]<=z) {node[p].tag=node[p].col=col;return;} change(node[p].lson,w,x,y,z,col),change(node[p].rson,w,x,y,z,col); } int find(int p,int w,int x,int y,int z) { if(!p) return 0; if(node[p].mx[0]<w||node[p].mn[0]>x) return 0; if(node[p].mx[1]<y||node[p].mn[1]>z) return 0; pushdown(p); if(node[p].pla[0]==w&&node[p].pla[1]==y) return node[p].col; return max(find(node[p].lson,w,x,y,z),find(node[p].rson,w,x,y,z)); }