[算法] 數據結構 splay(伸展樹)解析


前言

splay學了已經很久了,只不過一直沒有總結,鴿了好久來寫一篇總結。

先介紹 splay:亦稱伸展樹,為二叉搜索樹的一種,部分操作能在 \(O( \log n)\) 內完成,如插入、查找、刪除、查詢序列第 \(k\) 大、查詢前綴(比查詢的數小的數中最大的數)、查詢后綴(比查詢的數大的數中最小的數)等操作,甚至能夠實現區間平移。它由 Daniel Sleator 和 Robert Endre Tarjan 在1985年發明的。注:時間復雜度是均攤為 \(O(\log n)\) ,是經過嚴謹的證明的,單個操作可能退化成 \(O(n)\)

本文例題鏈接

算法思想

先做一個小小的引入:輸入法中,你經常使用詞語,會在詞條中靠前的位置。實現過程可以使用 splay。

splay 是二叉搜索樹的一種,這里簡單介紹一下二叉搜索樹。

對於一棵二叉樹,滿足樹上任意節點,它的左子樹上任意節點滿足比當前節點的權值小,右子樹上任意節點的權值比當前節點的權值大。則稱這棵樹為二叉搜索樹。

可以利用二叉搜索樹的性質來進行操作,比當前節點的權值小就在左子樹查找,權值大就在右子樹查找。

理想狀態下,若該二叉樹為一顆完全二叉樹,則單次操作時間復雜度為 \(O(\log n)\) 。但這顆二叉樹可能退化成一條鏈,這樣單次時間復雜度為 \(O(n)\)

splay 樹在這上面進行了改進,通過不斷改變樹的形態來保證不會退化,均攤時間復雜度為 \(O(\log n)\) 。基本思想是把搜索頻率高的點放在深度小的位置,為了操作方便,可以認為每次操作的點都是頻率高的。常常把操作的點,或是操作區間的兩個端點放在根或根的附近的位置,那么會涉及到旋轉操作。

根據勢能函數分析(我不會),splay 的時間復雜度上限為 \(O((m+n)\log n)\) ,但這個上限是有波動的。

基本操作

建議配合注釋一起使用。

結構體中應包含以下信息:

struct Splay_Node {
	int son[2], val, cnt, siz, fa;
//分別是:兩個兒子,權值,副本數,子樹大小,父親節點
	#define ls t[pos].son[0] //宏定義左兒子,方便一些
	#define rs t[pos].son[1] //右兒子,同上
};

簡單說明一下,副本數為權值為 val 的數的個數。

New

開辟新節點,里面的值隨需求變化,以下是幾個重要的值。

int New(int val, int fa) {
	t[++tot].fa = fa, t[tot].cnt = t[tot].siz = 1, t[tot].val = val;
	return tot;
}

Build

建立splay樹,將極小值置為根節點,極大值作為根節點的右兒子,滿足二叉搜索樹的性質,代碼:

void Build() {
	root = New(-INF, 0); //極小值為根節點 
	t[root].son[1] = New(INF, root); //極大值為右兒子
	Update(root); //更新根節點信息
}

寫這段代碼的主要原因是:使得 splay 的每個節點不會爆掉邊界,否則很容易就 RE 。

Ident

判斷該節點為父節點的左兒子還是右兒子,左兒子為 \(0\) ,右兒子為 \(1\)

bool Ident(int pos) { return t[t[pos].fa].son[1] == pos; } 

Update

更新子樹大小,還更新節點信息(由需求所定)。

void Update(int pos) {
	t[pos].siz = t[ls].siz + t[rs].siz + t[pos].cnt; //子樹大小為左右子樹大小加上自己的副本數
}

Connect

將一對點變為父子關系。

void Connect(int pos, int fa, int flag) {//依次是:子節點,父節點,哪個兒子
	t[fa].son[flag] = pos;//將fa的兒子置為pos
	t[pos].fa = fa;//將pos的父親置為fa
}

Rotate

既然要把一個點旋轉到根節點,那么就必須先掌握單旋操作,具體分兩個情況討論。

左兒子旋轉至父節點

在這里插入圖片描述

如上圖,需要進行幾次轉換: \(x\) 的左兒子變為 \(y\) 的右兒子, \(y\) 的右兒子變為\(x\)\(a\) 的子節點變為 \(y\)

那么程序可以寫為:

void Rotate(int pos) {//這里的flag1=0,可以按照上述的三個轉換進行驗證這段程序是對的
	int fa = t[pos].fa, grand = t[fa].fa;
	int flag1 = Ident(pos), flag2 = Ident(fa);
	Connect(pos, grand, flag2);
	Connect(t[pos].son[flag1 ^ 1], fa, flag1);
	Connect(fa, pos, flag1 ^ 1);
	Update(fa); Update(pos);
}

右兒子旋轉至父節點

可以視為上圖的逆操作: \(y\) 的右兒子變為 \(x\) 的左兒子, \(x\) 的左兒子變為\(y\)\(a\) 的子節點變為 \(x\)

那么程序依舊可以寫為:

void Rotate(int pos) {//這里的flag1=1,可以按照上述的三個轉換進行驗證這段程序是對的
	int fa = t[pos].fa, grand = t[fa].fa;
	int flag1 = Ident(pos), flag2 = Ident(fa);
	Connect(pos, grand, flag2);
	Connect(t[pos].son[flag1 ^ 1], fa, flag1);
	Connect(fa, pos, flag1 ^ 1);
	Update(fa); Update(pos);
}

綜上所述,Rotate 操作可以不用判斷左右節點,寫法為上述程序。

Splay

聽名字就知道,這是splay樹的核心操作。

函數 \(splay(pos,to)\) 定義為:將編號為 \(x\) 的節點,旋轉至父親為 \(to\) 的節點(即 \(to\) 的其中一個子節點,且進行 splay 后依然滿足二叉搜索樹的性質)。

顯然有一種方法:對於當前節點 \(pos\) ,不停進行 \(Rotate(pos)\) ,知道 \(pos\) 的父節點為 \(to\) 為止。

但是這並不能使該 splay 樹的形態發生太大的改變。splay 的目的是改變樹的形態,有一種改進的方法:雙旋。順帶說明一下,單旋會被卡成 \(O(nm)\) 。(具體我也不知道怎么卡)

雙旋即一次旋轉兩次,設當前點為 \(x\) ,父親節點為 \(y\) ,爺爺為 \(z\) 。具體分為兩種情況,這里只證明正確性。

x、y、z 形成一條鏈

在這里插入圖片描述

這種情況先單旋 \(y\) 在單旋 \(x\) 。過程見下圖:
在這里插入圖片描述

顯然,在上述過程中,嚴謹地滿足了 \(val[x]>val[y]>val[z]\)

x、y、z 形成“<”或 “>”

直接進行兩次單旋操作,正確性顯然。

Code

代碼很短,只有三行。

void Splay(int pos, int to) {
	for(int fa = t[pos].fa; t[pos].fa != to; Rotate(pos), fa = t[pos].fa)
		if(t[fa].fa != to) Ident(pos) == Ident(fa) ? Rotate(fa) : Rotate(pos);
//Ident(pos) == Ident(fa)意味着pos和fa成為了一條鏈的形狀,否則為“<”或“>”。
	if(!to) root = pos;//更新根節點,根節點的父親值為0
}

總結

這些是 splay 的基本操作,之后的所有操作都是建立在這些之上的。

引申操作

Find

定義 \(Find(val)\) :查詢權值為 \(val\) 的點的編號,若沒有該點就返回 \(0\)

利用 splay 為二叉搜索樹的性質,若 \(val\) 小於當前節點的權值,則在左子樹中查找;若大於則在右子樹中查找。知道找到當前節點的編號為 \(0\) 或當前節點的權值等於 \(val\) 的時候返回改點的下標。

int Find(int pos, int val) {
	if(!pos) return 0;//空節點直接返回
	if(val == t[pos].val) return pos;//等於就直接返回節點編號
	else if(val < t[pos].val) return Find(ls, val);//在左子樹中查找
	else return Find(rs, val);//在右子樹中查找
}

Insert

即插入操作, 需要插入權值為 \(val\) 的值。

其思想跟 \(Find\) 函數差不多,利用二叉搜索樹的性質直接就可以找到插入的位置。具體分為兩類:

  1. 有權值為 \(val\) 的點 \(pos\) ,直接使得副本數加 \(1\) 即可。
  2. 沒有權值為 \(val\) 的點 \(pos\) ,則開辟一個新的節點權值為 \(val\)

注意 \(pos\) 應傳實參,因為若開辟了新的節點,其父節點的對應兒子也需要改變。

void Insert(int &pos, int val, int fa) {//pos為實參
	if (!pos) Splay(pos = New(val, fa), 0);
	else if (val == t[pos].val) { ++t[pos].siz, ++t[pos].cnt; Splay(pos, 0); }
	else if (val < t[pos].val) Insert(ls, val, pos);
	else Insert(rs, val, pos);
}

Erase

即刪除操作, \(Erase(val)\) 定義為:刪除所維護的序列中權值為 \(val\) 的一個節點(如果有的話)。

可以先找到權值為 \(val\) 的節點並定義其編號為 \(pos\) ,分兩種情況。

  1. 若當前節點的副本數大於 \(1\) 時,即 \(t[pos].cnt>1\) 時,可以刪除其中一個副本即可,但並沒有刪除這個節點。
  2. 否則,則需要刪除該節點。需要先將 \(pos\) splay 到根節點。找到它的前綴的編號 \(l\) 和它的后綴的編號 \(r\) ,則 \(t[l].val\leq val \leq t[r].val\) 。顯然, \((t[l].val,t[r].val)\) 區間內的數只有一個,即 \(pos\) 。將 \(l\) splay 至根節點, \(r\) splay 至 \(l\) 的右兒子,則 \(pos\) 必會在 \(r\) 的左兒子處,因為 \(l\)\(r\)\(pos\) 必回滿足二叉搜索樹的性質。然后直接刪除 \(r\) 的左兒子即可。
void Erase(int val) {
	int pos = Find(root, val);//找到權值為 val 的點。
	if (!pos) return;//沒有改節點直接返回,沒有難倒刪空氣?
	if (t[pos].cnt > 1) { --t[pos].siz, --t[pos].cnt; Splay(pos, 0); Update(pos); return; }//對應情況1
	Splay(pos, 0);
	int l = ls, r = rs;
	while (t[l].son[1]) l = t[l].son[1];//找到前綴
	while (t[r].son[0]) r = t[r].son[0];//找到后繼
	Splay(l, 0); Splay(r, l);//對應情況2
	t[r].son[0] = 0;
	Update(r); Update(l);
}

這里在提供一種做法,與 \(Find\) 函數的做法類似,可以說是其的升級版。總體框架不變,主要是針對第二種情況,將其旋轉到根節點在進行刪除,這種寫法還是比較常見的。

void Erase(int pos, int val) {
	if(!pos) return;
	if(val == t[pos].val) {
		if(t[pos].cnt > 1) { t[pos].siz--, t[pos].cnt--; Splay(pos, 0); return; }
		if(ls) Rotate(ls), Erase(pos, val);//有左兒子跟左兒子交換
		else if(rs) Rotate(rs), Erase(pos, val);//有右兒子就跟右兒子交換
		else {//沒有兒子就直接刪除,注意必須刪除其父親的對應兒子
			int newroot = t[pos].fa;
			t[t[pos].fa].son[Ident(pos)] = 0;
			Splay(newroot, 0);
		}
		return;
	}
	else if(val < t[pos].val) rase(ls, val);
	else Erase(rs, val);
}

Query_kth

查詢 \(val\) 在序列是第幾大的樹,即按照從小到大的順序排序后, \(val\) 的排名,沒有 \(val\) 輸出返回 \(-1\)

代碼使用遞歸實現,考慮對於當前節點 \(pos\) ,比 \(val\) 小的數都在左子樹內,即有 \(t[ls].siz\) 個樹比 \(t[pos].val\) 小。

對於局部解,可以將 \(Querykth(pos,val)\) 函數理解為 \(pos\) 的子樹中,小於 \(val\) 的值有多少。

則可以分為三種情況來討論。

  1. \(val=t[pos].val\) 時,即找到了該節點,返回比它小的數的個數即可,即左子樹的節點數加 \(1\)
  2. \(val<t[pos].val\) 時, \(val\) 左子樹中,在左子樹中查詢該節點的排名。
  3. \(val>t[pos].val\) 時, 是最麻煩的部分。 \(val\) 右子樹中,左子樹與當前節點都會為答案做貢獻,先將其統計至答案中,在求出右子樹對於答案的貢獻。

注意,最后的答案是包含了極小值的,所以找到后的答案應該減一,這一部分我寫在了主函數里,所以沒找到會輸出 \(-1\)

int Query_kth(int pos, int val) {
	if(!pos) return 0;//沒有輸出-1
	if(val == t[pos].val) { int res = t[ls].siz + 1; Splay(pos, 0); return res; }//對應情況1
	else if(val < t[pos].val) return Query_kth(ls, val);//對於情況2
	//下兩行代碼對應情況3
	int res = t[ls].siz + t[pos].cnt;//找到后splay維護形態會導致子樹的大小變化,因此先記錄答案
	return Query_kth(rs, val) + res;
}

Query_val

查詢區間的第 \(k\) 小的數。

可以看做上一個操作的逆操作吧,若 \(k\) 都大於了區間的所有數的個數,就直接返回極大值。

同樣,對於局部解,可以將 \(Queryval(pos,k)\) 函數理解為 \(pos\) 的子樹中,第 \(k\) 大值為多少。

又可以分為三個情況:

  1. \(t[ls].siz\geq k\) 時,即所求答案在左子樹,在左邊查詢即可。
  2. \(t[ls].siz+t[pos].cnt\geq k\) 時, 答案為 \(t[pos].val\) ,因為第 \(t[ls].siz+1\) 小至 \(t[ls].siz+t[pos].cnt\) 的數全部權值都為 \(t[pos].val\)
  3. 否則,答案全部會在右子樹當中,查詢右子樹第 \(k-t[ls].siz-t[pos].cnt\) 大,因為當前節點與左兒子一定比右子樹任何一個數小。

同樣的需要注意,最后的答案是包含了極小值的,同樣這一部分我寫在了主函數里,查詢的時候需要查詢第 \(k+1\) 大的那個數。

int Query_val(int pos, int rank) {
	if(!pos) return INF;
	if(t[ls].siz >= rank) return Query_val(ls, rank);
	else if(t[ls].siz + t[pos].cnt >= rank) { Splay(pos, 0); return t[pos].val; }
	return Query_val(rs, rank - t[ls].siz - t[pos].cnt);
}

Get_Pre、Get_Nxt

\(Erase\) 操作中提到過,可以使用那樣的做法。

亦可使用在文末的代碼中稍快的做法,與 \(Find\) 函數相似,這里就不多說了。(其實是不想打字了

也可以參照這段代碼將一些操作寫為非遞歸的寫法,會更快一些。

總結

有些細心的同學可能已經發現了,幾乎每個操作都有 splay 操作來維護當前樹的形態,保證時間復雜度。

C++代碼

只是將上述操作拼起來放在一個代碼里。

說明一下操作的幾種類型:

  1. 插入 \(x\) 數。
  2. 刪除 \(x\) 數(若有多個相同的數,因只刪除一個)。
  3. 查詢 \(x\) 數的排名(排名定義為比當前數小的數的個數 \(+1\) )。
  4. 查詢排名為 \(x\) 的數。
  5. \(x\) 的前驅(前驅定義為小於 \(x\),且最大的數)。
  6. \(x\) 的后繼(后繼定義為大於 \(x\),且最小的數)。

不是特別長,實現的方法也並不困難,打的時候必須得注意,完整沒附上注釋的代碼:

#include <cstdio>
#define INF 0x3f3f3f3f
#define Ident(pos) ( t[t[pos].fa].son[1] == pos )
const int MAXN = 1e5 + 5;
struct Splay_Tree {
	int son[2], val, cnt, siz, fa;
	#define ls t[pos].son[0]
	#define rs t[pos].son[1]
};
int root, tot, q;
Splay_Tree t[MAXN];
int New(int val, int fa) {
	t[++tot].fa = fa, t[tot].cnt = t[tot].siz = 1, t[tot].val = val;
	return tot;
}
void Update(int pos) { t[pos].siz = t[ls].siz + t[rs].siz + t[pos].cnt; }
void Build() { root = New(-INF, 0); t[root].son[1] = New(INF, root); Update(root); }
void Connect(int pos, int fa, int flag) { t[fa].son[flag] = pos, t[pos].fa = fa; }
void Rotate(int pos) {
	int fa = t[pos].fa, grand = t[fa].fa;
	int flag1 = Ident(pos), flag2 = Ident(fa);
	Connect(pos, grand, flag2);
	Connect(t[pos].son[flag1 ^ 1], fa, flag1);
	Connect(fa, pos, flag1 ^ 1);
	Update(fa); Update(pos);
}
void Splay(int pos, int to) {
	for (int fa = t[pos].fa; t[pos].fa != to; Rotate(pos), fa = t[pos].fa)
		if (t[fa].fa != to) Ident(pos) == Ident(fa) ? Rotate(fa) : Rotate(pos);
	if (!to) root = pos;
}
int Find(int pos, int val) {
	if (!pos) return 0;
	if (val == t[pos].val) return pos;
	else if (val < t[pos].val) return Find(ls, val);
	return Find(rs, val);
}
void Insert(int &pos, int val, int fa) {
	if (!pos) Splay(pos = New(val, fa), 0);
	else if (val == t[pos].val) { ++t[pos].siz, ++t[pos].cnt; Splay(pos, 0); }
	else if (val < t[pos].val) Insert(ls, val, pos);
	else Insert(rs, val, pos);
}
void Erase(int val) {
	int pos = Find(root, val);
	if (!pos) return;
	if (t[pos].cnt > 1) { --t[pos].siz, --t[pos].cnt; Splay(pos, 0); Update(pos); return; }
	Splay(pos, 0);
	int l = ls, r = rs;
	while (t[l].son[1]) l = t[l].son[1];
	while (t[r].son[0]) r = t[r].son[0];
	Splay(l, 0); Splay(r, l);
	t[r].son[0] = 0;
	Update(r); Update(l);
}
int Query_Rnk(int pos, int val) {
	if (!ls && !rs && val != t[pos].val) { Splay(pos, 0); return 0; }
	else if (val == t[pos].val) { int res = t[ls].siz + 1; Splay(pos, 0); return res; }
	else if (val < t[pos].val) return Query_Rnk(ls, val);
	int res = t[ls].siz + t[pos].cnt;
	return Query_Rnk(rs, val) + res;
}
int Query_Kth(int pos, int rank) {
	if (t[ls].siz >= rank && ls) return Query_Kth(ls, rank);
	else if (t[ls].siz + t[pos].cnt >= rank) { Splay(pos, 0); return t[pos].val; }
	else if (rs) return Query_Kth(rs, rank - t[ls].siz - t[pos].cnt);
	Splay(pos, 0); return 0;
}
int Get_Pre(int val) {
	int pos = root, res = root;
	pos = root;
	while (pos) {
		if (t[pos].val < val) res = pos, pos = rs;
		else pos = ls;
	}
	Splay(res, 0);
	return t[res].val;
}
int Get_Nxt(int val) {
	int pos = root, res = root;
	while (pos) {
		if (t[pos].val > val) res = pos, pos = ls;
		else pos = rs;
	}
	Splay(res, 0);
	return t[res].val;
}
int main() {
	Build(); scanf("%d", &q); 
	for (int i = 1, opt, x; i <= q; i++) {
		scanf("%d %d", &opt, &x);
		if (opt == 1) Insert(root, x, 0);
		else if (opt == 2) Erase(x);
		else if (opt == 3) printf("%d\n", Query_Rnk(root, x) - 1);
		else if (opt == 4) printf("%d\n", Query_Kth(root, x + 1));
		else if (opt == 5) printf("%d\n", Get_Pre(x));
		else printf("%d\n", Get_Nxt(x));
	}
	return 0;
}

補充

原本數據有億點水,錯的代碼都能過。

數據加強版

#include <cstdio>
#define INF 2147483647
#define Ident(pos) ( t[t[pos].fa].son[1] == pos )
const int MAXN = 2e6 + 5;
struct Splay_Tree {
	int son[2], val, cnt, siz, fa;
	#define ls t[pos].son[0]
	#define rs t[pos].son[1]
};
int root, tot, ans, n, m;
Splay_Tree t[MAXN];
int New(int val, int fa) {
	t[++tot].fa = fa, t[tot].cnt = t[tot].siz = 1, t[tot].val = val;
	return tot;
}
void Update(int pos) { t[pos].siz = t[ls].siz + t[rs].siz + t[pos].cnt; }
void Build() { root = New(-INF, 0); t[root].son[1] = New(INF, root); Update(root); }
void Connect(int pos, int fa, int flag) { t[fa].son[flag] = pos, t[pos].fa = fa; }
void Rotate(int pos) {
	int fa = t[pos].fa, grand = t[fa].fa;
	int flag1 = Ident(pos), flag2 = Ident(fa);
	Connect(pos, grand, flag2);
	Connect(t[pos].son[flag1 ^ 1], fa, flag1);
	Connect(fa, pos, flag1 ^ 1);
	Update(fa); Update(pos);
}
void Splay(int pos, int to) {
	for (int fa = t[pos].fa; t[pos].fa != to; Rotate(pos), fa = t[pos].fa)
		if (t[fa].fa != to) Ident(pos) == Ident(fa) ? Rotate(fa) : Rotate(pos);
	if (!to) root = pos;
}
int Find(int pos, int val) {
	if (!pos) return 0;
	if (val == t[pos].val) return pos;
	else if (val < t[pos].val) return Find(ls, val);
	return Find(rs, val);
}
void Insert(int &pos, int val, int fa) {
	if (!pos) Splay(pos = New(val, fa), 0);
	else if (val == t[pos].val) { ++t[pos].siz, ++t[pos].cnt; Splay(pos, 0); }
	else if (val < t[pos].val) Insert(ls, val, pos);
	else Insert(rs, val, pos);
}
void Erase(int val) {
	int pos = Find(root, val);
	if (!pos) return;
	if (t[pos].cnt > 1) { --t[pos].siz, --t[pos].cnt; Splay(pos, 0); Update(pos); return; }
	Splay(pos, 0);
	int l = ls, r = rs;
	while (t[l].son[1]) l = t[l].son[1];
	while (t[r].son[0]) r = t[r].son[0];
	Splay(l, 0); Splay(r, l);
	t[r].son[0] = 0;
	Update(r); Update(l);
}
int Query_Rnk(int pos, int val) {
	if (!ls && !rs && val != t[pos].val) { Splay(pos, 0); return 0; }
	else if (val == t[pos].val) { int res = t[ls].siz + 1; Splay(pos, 0); return res; }
	else if (val < t[pos].val) return Query_Rnk(ls, val);
	int res = t[ls].siz + t[pos].cnt;
	return Query_Rnk(rs, val) + res;
}
int Query_Kth(int pos, int rank) {
	if (t[ls].siz >= rank && ls) return Query_Kth(ls, rank);
	else if (t[ls].siz + t[pos].cnt >= rank) { Splay(pos, 0); return t[pos].val; }
	else if (rs) return Query_Kth(rs, rank - t[ls].siz - t[pos].cnt);
	Splay(pos, 0); return 0;
}
int Get_Pre(int val) {
	int pos = root, res = root;
	pos = root;
	while (pos) {
		if (t[pos].val < val) res = pos, pos = rs;
		else pos = ls;
	}
	Splay(res, 0);
	return t[res].val;
}
int Get_Nxt(int val) {
	int pos = root, res = root;
	while (pos) {
		if (t[pos].val > val) res = pos, pos = ls;
		else pos = rs;
	}
	Splay(res, 0);
	return t[res].val;
}
int main() {
	Build(); scanf("%d %d", &n, &m);
	for (int i = 1, a; i <= n; i++) scanf("%d", &a), Insert(root, a, 0);
	for (int i = 1, opt, x, last = 0; i <= m; i++) {
		scanf("%d %d", &opt, &x); x ^= last;
		if (opt == 1) Insert(root, x, 0);
		else if (opt == 2) Erase(x);
		else if (opt == 3) {
			Insert(root, x, 0);
			last = Query_Rnk(root, x) - 1;
			Erase(x);
		}
		else if (opt == 4) last = Query_Kth(root, x + 1);
		else if (opt == 5) last = Get_Pre(x);
		else last = Get_Nxt(x);
		if (opt >= 3) ans ^= last;
	}
	printf("%d", ans);
	return 0;
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM