[學習筆記] kd-tree


本文參考這位dalao的題解

前置技能:二叉查找樹

其實kd-tree很簡單的啦

和BST都差不多的啦

就是在划分的時候把每一維都比較一下就行啦

(\(dalao\)的kd-tree教程)

然而本蒟蒻是完全看不懂啊qwq

於是我們從頭講起吧:

step 1

首先,我們回憶一下BST,

它是以一個關鍵字\(val\),來滿足它的兩個性質反正大家都知道就懶得寫了.

而kd-tree,則是對於一個\(k\)維的點(也就是有\(k\)個關鍵字),

來弄一個像BST的數據結構.

下面以2d-tree為例(也就是平面內的點)來介紹一下吧.

首先,看圖:

upd:圖片出鍋了.

(蒟蒻畫圖水平有限勉強看下吧)

如果我們只以一維來給這些點排序的話,(假設就以\(y\)軸)

我們會發現,\(x\)軸就沒有了用處,

並且,中間的幾個點還很尷尬(\(y\)軸都一樣..)

因此,我們有一種划分的方法:

將每一維交替着划分.

比如說,我們這一層是以\(x\)軸划分的,

那么下一層就是以\(y\)軸划分.

這樣建出來的樹也很出色不要問我為什么人家也是蒟蒻qwq

step 2

接下來,就要正式講建樹了!

其實,划分的過程在上面已經講了.

但是,為了保持樹的平衡,

我們在建樹的時候,可以直接取中間的點.

依然以上一張圖為例吧,

首先,我們以\(x\)軸來划分,

那么中間的點顯然就是這個紅色的:

然后我們在將其它的點分成兩部分:

(加粗的線即為分割線)

更加 直觀一點的話,就是這樣:

而我們建的樹,就長這樣(其實才就一個點):

接下來,我們再在它的兩個兒子中以\(y\)軸來分,

由於有多個一樣的點,我們隨便找一個:

切開后就是這樣:

而樹就長這樣:

然后,我們再一個個分,最后就成了這樣:

總之,就是說,在划分的時候,

我們先找到中間的那個點,將兩邊分割開來,

再對於兩個兒子以另一維來分割.

並且,頭文件algorithm還有一個方便的操作——函數:nth_element.

它能將序列中第\(k\)大的數放在第\(k\)位,

\(k\)小的放在前面,比\(k\)大的放在后面(但是沒有排序,也就是僅僅於第\(k\)大的比較).

代碼如下:

nth_element(a+l,a+k,a+r+1,cmp);

所以說,建樹的代碼也可以出來了:

inline int New(){
	if(top) return sta[top--];//這個地方先埋個坑(先不管它)
	return ++tot;
}

bool cmp(node a,node b){return a.pla[now]<b.pla[now];}//now表示現在比較的是第幾維

inline int build(int l,int r,int opt){
	if(l>r) return 0;
	int x=New(),mid=(l+r)>>1;now=opt;
	nth_element(a+l,a+mid,a+r+1,cmp);t[x].place=a[mid];//這里表示當前的點的位置
	t[x].ls=build(l,mid-1,opt^1);t[x].rs=build(mid+1,r,opt^1);
	update(x);return x;//update等下會講的
}

(感覺埋了好多坑了...)

step 3

接下來,讓我們了解下每個節點儲存的信息.(順便說一句,本人沉迷於\(struct\))

\(ls,rs\):左兒子,右兒子.

\(size\):子樹大小.

\(place\):一個結構體,表示點的位置.

\(mx[k]\):在當前節點的子樹中第\(k\)維坐標最大值.

\(mi[k]\):在當前節點的子樹中第\(k\)維坐標最小值.

其中,\(mx[k],mi[k]\)表示了當前節點及其子樹的管轄范圍(在查詢時有用),

因此\(update\)就是來更新\(mx,mi,size\)的:

inline void update(int p){
	for(int i=0;i<=1;i++){
		t[p].mx[i]=t[p].mn[i]=t[p].place.pla[i];
		if(t[p].ls) t[p].mx[i]=max(t[p].mx[i],t[t[p].ls].mx[i]),t[p].mn[i]=min(t[p].mn[i],t[t[p].ls].mn[i]);
		if(t[p].rs) t[p].mx[i]=max(t[p].mx[i],t[t[p].rs].mx[i]),t[p].mn[i]=min(t[p].mn[i],t[t[p].rs].mn[i]);
	}
	t[p].size=t[t[p].ls].size+t[t[p].rs].size+1;
}

比如,我們拿之前的圖,

紅色的框就代表紅色節點的范圍:

而這范圍有什么用呢?

別急,講查詢的時候就知道了.

首先,我們來講最近點(曼哈頓距離),

我們假設要查詢圖中離點\((4,7)\)最近的點.

(先把原圖放出來,紅色的為查詢的點)

那么首先,我們找到了第一個節點(5,5),

先統計答案,

inline int dis(node a,node b){//node表示的是圖中的點
    return abs(a.pla[0]-b.pla[0])+abs(a.pla[1]-b.pla[1]);
}

然后我們計算到它兩個兒子的范圍的最短距離.

因為可能答案就在兒子的子樹中,因此我們計算的是到達范圍的最短距離(而不是到達兒子本身)

這時候,\(mx\)\(mi\)就有用了,

inline int getdis(node a,int p){//p是子樹節點的編號
	int ret=0;
	for(int i=0;i<=1;i++) ret+=max(0,a.pla[i]-t[p].mx[i])+max(0,t[p].mn[i]-a.pla[i]);
	return ret;
}

如果最短距離都大於等於\(ans\)的話,那這棵子樹就沒必要搜了.

另外,由於kd-tree的本質是搜索+剪枝,

因此,我們可以在查詢的時候,先搜索最短距離短的子樹,

因為\(ans\)會在搜索時更新,

所以說不定在搜完一棵后另一棵就會被減掉了.

然后,查詢的代碼就出來了:

inline void query(node ret,int p){
	ans=min(ans,dis(ret,t[p].place));
	int teml=INF,temr=INF;
	if(t[p].ls)	teml=getdis(ret,t[p].ls);
	if(t[p].rs)	temr=getdis(ret,t[p].rs);
	if(teml<temr){
		if(teml<ans) query(ret,t[p].ls);
		if(temr<ans) query(ret,t[p].rs);		
	}
	else{
		if(temr<ans) query(ret,t[p].rs);
		if(teml<ans) query(ret,t[p].ls);
	}
}

step 4

講完了查詢,我們來講插入吧.

其實這就和BST一樣啦.

一直比較到空節點在插入就行啦.

inline void insert(node ret,int &p,int opt){
	if(!p){p=New();t[p].place=ret;t[p].ls=t[p].rs=0;update(p);return ;}
	if(ret.pla[opt]<=t[p].place.pla[opt]) insert(ret,t[p].ls,opt^1);
	else insert(ret,t[p].rs,opt^1);
	update(p);check(p,opt);
}

然而,會有一件細思極恐的事情:

在插入多了后,我們的樹可能會退化成一條鏈!

所以,我們要利用替罪羊樹的思想,

設一個值\(\alpha=0.75\)(當然想設其它的也可以),

當某點的\(size*\alpha\)小於它某棵子樹的\(size\)時,就直接拍扁重建.

\(size\):終於想起我了

而代碼也很簡單:

inline int New(){
	if(top) return sta[top--];//這下知道什么意思了吧(拍扁重建時直接返回節點就好)
	return ++tot;
}

inline void pia(int p,int cnt){//有聲音的代碼[滑稽]
	if(t[p].ls) pia(t[p].ls,cnt);//cnt表示已經存了多少個點了
	a[cnt+t[t[p].ls].size+1]=t[p].place,sta[++top]=p;//拍扁后用一個棧來存節點
	if(t[p].rs) pia(t[p].rs,cnt+t[t[p].ls].size+1);
}

inline void check(int &p,int opt){//判斷是否需要重建
	if(t[p].size*alpha<t[t[p].ls].size||t[p].size*alpha<t[t[p].rs].size)
		pia(p,0),p=build(1,t[p].size,opt);
}

那么到這里,kd-tree就基本講完啦!

step 5

來看例題吧:洛谷P4169 [Violet]天使玩偶/SJY擺棋子

這題就是板子了(當然也可以用CDQ分治寫).

上代碼吧:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define INF 0x3f3f3f3f
using namespace std;

inline int read(){
	int sum=0,f=1;char c=getchar();
	while(c>'9'||c<'0'){if(c=='-') f=-1;c=getchar();}
	while(c<='9'&&c>='0'){sum=sum*10+c-'0';c=getchar();}
	return sum*f;
}
const double alpha=0.75;
struct node{int pla[2];}a[2000001];
struct tree{int mx[2],mn[2],size,ls,rs;node place;}t[2000001];
int n,m,rt,tot,now,ans;
int sta[2000001],top=0;

inline int New(){
	if(top) return sta[top--];
	return ++tot;
}

bool cmp(node a,node b){return a.pla[now]<b.pla[now];}

inline void update(int p){
	for(int i=0;i<=1;i++){
		t[p].mx[i]=t[p].mn[i]=t[p].place.pla[i];
		if(t[p].ls) t[p].mx[i]=max(t[p].mx[i],t[t[p].ls].mx[i]),t[p].mn[i]=min(t[p].mn[i],t[t[p].ls].mn[i]);
		if(t[p].rs) t[p].mx[i]=max(t[p].mx[i],t[t[p].rs].mx[i]),t[p].mn[i]=min(t[p].mn[i],t[t[p].rs].mn[i]);
	}
	t[p].size=t[t[p].ls].size+t[t[p].rs].size+1;
}

inline int build(int l,int r,int opt){
	if(l>r) return 0;
	int x=New(),mid=(l+r)>>1;now=opt;
	nth_element(a+l,a+mid,a+r+1,cmp);t[x].place=a[mid];
	t[x].ls=build(l,mid-1,opt^1);t[x].rs=build(mid+1,r,opt^1);
	update(x);return x;
}

inline void pia(int p,int cnt){
	if(t[p].ls) pia(t[p].ls,cnt);
	a[cnt+t[t[p].ls].size+1]=t[p].place,sta[++top]=p;
	if(t[p].rs) pia(t[p].rs,cnt+t[t[p].ls].size+1);
}

inline void check(int &p,int opt){
	if(t[p].size*alpha<t[t[p].ls].size||t[p].size*alpha<t[t[p].rs].size)
		pia(p,0),p=build(1,t[p].size,opt);
}

inline void insert(node ret,int &p,int opt){
	if(!p){p=New();t[p].place=ret;t[p].ls=t[p].rs=0;update(p);return ;}
	if(ret.pla[opt]<=t[p].place.pla[opt]) insert(ret,t[p].ls,opt^1);
	else insert(ret,t[p].rs,opt^1);
	update(p);check(p,opt);
}

inline int getdis(node a,int p){
	int ret=0;
	for(int i=0;i<=1;i++) ret+=max(0,a.pla[i]-t[p].mx[i])+max(0,t[p].mn[i]-a.pla[i]);
	return ret;
}

inline int dis(node a,node b){return abs(a.pla[0]-b.pla[0])+abs(a.pla[1]-b.pla[1]);}

inline void query(node ret,int p){
	ans=min(ans,dis(ret,t[p].place));
	int teml=INF,temr=INF;
	if(t[p].ls)	teml=getdis(ret,t[p].ls);
	if(t[p].rs)	temr=getdis(ret,t[p].rs);
	if(teml<temr){
		if(teml<ans) query(ret,t[p].ls);
		if(temr<ans) query(ret,t[p].rs);		
	}
	else{
		if(temr<ans) query(ret,t[p].rs);
		if(teml<ans) query(ret,t[p].ls);
	}
}

int main(){
	n=read();m=read();
	for(int i=1;i<=n;i++) a[i].pla[0]=read(),a[i].pla[1]=read();
	rt=build(1,n,0);
	for(int	i=1;i<=m;i++){
		int opt=read();node ret;
		ret.pla[0]=read();ret.pla[1]=read();
		if(opt==1) insert(ret,rt,0);
		else if(opt==2) ans=INF,query(ret,rt),printf("%d\n",ans);
	}
	return 0;
}

可能還會更新(埋坑)...


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM