前言
splay學了已經很久了,只不過一直沒有總結,鴿了好久來寫一篇總結。
先介紹 splay:亦稱伸展樹,為二叉搜索樹的一種,部分操作能在 \(O( \log n)\) 內完成,如插入、查找、刪除、查詢序列第 \(k\) 大、查詢前綴(比查詢的數小的數中最大的數)、查詢后綴(比查詢的數大的數中最小的數)等操作,甚至能夠實現區間平移。它由 Daniel Sleator 和 Robert Endre Tarjan 在1985年發明的。注:時間復雜度是均攤為 \(O(\log n)\) ,是經過嚴謹的證明的,單個操作可能退化成 \(O(n)\) 。
算法思想
先做一個小小的引入:輸入法中,你經常使用詞語,會在詞條中靠前的位置。實現過程可以使用 splay。
splay 是二叉搜索樹的一種,這里簡單介紹一下二叉搜索樹。
對於一棵二叉樹,滿足樹上任意節點,它的左子樹上任意節點滿足比當前節點的權值小,右子樹上任意節點的權值比當前節點的權值大。則稱這棵樹為二叉搜索樹。
可以利用二叉搜索樹的性質來進行操作,比當前節點的權值小就在左子樹查找,權值大就在右子樹查找。
理想狀態下,若該二叉樹為一顆完全二叉樹,則單次操作時間復雜度為 \(O(\log n)\) 。但這顆二叉樹可能退化成一條鏈,這樣單次時間復雜度為 \(O(n)\) 。
splay 樹在這上面進行了改進,通過不斷改變樹的形態來保證不會退化,均攤時間復雜度為 \(O(\log n)\) 。基本思想是把搜索頻率高的點放在深度小的位置,為了操作方便,可以認為每次操作的點都是頻率高的。常常把操作的點,或是操作區間的兩個端點放在根或根的附近的位置,那么會涉及到旋轉操作。
根據勢能函數分析(我不會),splay 的時間復雜度上限為 \(O((m+n)\log n)\) ,但這個上限是有波動的。
基本操作
建議配合注釋一起使用。
結構體中應包含以下信息:
struct Splay_Node {
int son[2], val, cnt, siz, fa;
//分別是:兩個兒子,權值,副本數,子樹大小,父親節點
#define ls t[pos].son[0] //宏定義左兒子,方便一些
#define rs t[pos].son[1] //右兒子,同上
};
簡單說明一下,副本數為權值為 val 的數的個數。
New
開辟新節點,里面的值隨需求變化,以下是幾個重要的值。
int New(int val, int fa) {
t[++tot].fa = fa, t[tot].cnt = t[tot].siz = 1, t[tot].val = val;
return tot;
}
Build
建立splay樹,將極小值置為根節點,極大值作為根節點的右兒子,滿足二叉搜索樹的性質,代碼:
void Build() {
root = New(-INF, 0); //極小值為根節點
t[root].son[1] = New(INF, root); //極大值為右兒子
Update(root); //更新根節點信息
}
寫這段代碼的主要原因是:使得 splay 的每個節點不會爆掉邊界,否則很容易就 RE 。
Ident
判斷該節點為父節點的左兒子還是右兒子,左兒子為 \(0\) ,右兒子為 \(1\) 。
bool Ident(int pos) { return t[t[pos].fa].son[1] == pos; }
Update
更新子樹大小,還更新節點信息(由需求所定)。
void Update(int pos) {
t[pos].siz = t[ls].siz + t[rs].siz + t[pos].cnt; //子樹大小為左右子樹大小加上自己的副本數
}
Connect
將一對點變為父子關系。
void Connect(int pos, int fa, int flag) {//依次是:子節點,父節點,哪個兒子
t[fa].son[flag] = pos;//將fa的兒子置為pos
t[pos].fa = fa;//將pos的父親置為fa
}
Rotate
既然要把一個點旋轉到根節點,那么就必須先掌握單旋操作,具體分兩個情況討論。
左兒子旋轉至父節點
如上圖,需要進行幾次轉換: \(x\) 的左兒子變為 \(y\) 的右兒子, \(y\) 的右兒子變為\(x\) , \(a\) 的子節點變為 \(y\) 。
那么程序可以寫為:
void Rotate(int pos) {//這里的flag1=0,可以按照上述的三個轉換進行驗證這段程序是對的
int fa = t[pos].fa, grand = t[fa].fa;
int flag1 = Ident(pos), flag2 = Ident(fa);
Connect(pos, grand, flag2);
Connect(t[pos].son[flag1 ^ 1], fa, flag1);
Connect(fa, pos, flag1 ^ 1);
Update(fa); Update(pos);
}
右兒子旋轉至父節點
可以視為上圖的逆操作: \(y\) 的右兒子變為 \(x\) 的左兒子, \(x\) 的左兒子變為\(y\) , \(a\) 的子節點變為 \(x\) 。
那么程序依舊可以寫為:
void Rotate(int pos) {//這里的flag1=1,可以按照上述的三個轉換進行驗證這段程序是對的
int fa = t[pos].fa, grand = t[fa].fa;
int flag1 = Ident(pos), flag2 = Ident(fa);
Connect(pos, grand, flag2);
Connect(t[pos].son[flag1 ^ 1], fa, flag1);
Connect(fa, pos, flag1 ^ 1);
Update(fa); Update(pos);
}
綜上所述,Rotate 操作可以不用判斷左右節點,寫法為上述程序。
Splay
聽名字就知道,這是splay樹的核心操作。
函數 \(splay(pos,to)\) 定義為:將編號為 \(x\) 的節點,旋轉至父親為 \(to\) 的節點(即 \(to\) 的其中一個子節點,且進行 splay 后依然滿足二叉搜索樹的性質)。
顯然有一種方法:對於當前節點 \(pos\) ,不停進行 \(Rotate(pos)\) ,知道 \(pos\) 的父節點為 \(to\) 為止。
但是這並不能使該 splay 樹的形態發生太大的改變。splay 的目的是改變樹的形態,有一種改進的方法:雙旋。順帶說明一下,單旋會被卡成 \(O(nm)\) 。(具體我也不知道怎么卡)
雙旋即一次旋轉兩次,設當前點為 \(x\) ,父親節點為 \(y\) ,爺爺為 \(z\) 。具體分為兩種情況,這里只證明正確性。
x、y、z 形成一條鏈
這種情況先單旋 \(y\) 在單旋 \(x\) 。過程見下圖:
顯然,在上述過程中,嚴謹地滿足了 \(val[x]>val[y]>val[z]\) 。
x、y、z 形成“<”或 “>”
直接進行兩次單旋操作,正確性顯然。
Code
代碼很短,只有三行。
void Splay(int pos, int to) {
for(int fa = t[pos].fa; t[pos].fa != to; Rotate(pos), fa = t[pos].fa)
if(t[fa].fa != to) Ident(pos) == Ident(fa) ? Rotate(fa) : Rotate(pos);
//Ident(pos) == Ident(fa)意味着pos和fa成為了一條鏈的形狀,否則為“<”或“>”。
if(!to) root = pos;//更新根節點,根節點的父親值為0
}
總結
這些是 splay 的基本操作,之后的所有操作都是建立在這些之上的。
引申操作
Find
定義 \(Find(val)\) :查詢權值為 \(val\) 的點的編號,若沒有該點就返回 \(0\) 。
利用 splay 為二叉搜索樹的性質,若 \(val\) 小於當前節點的權值,則在左子樹中查找;若大於則在右子樹中查找。知道找到當前節點的編號為 \(0\) 或當前節點的權值等於 \(val\) 的時候返回改點的下標。
int Find(int pos, int val) {
if(!pos) return 0;//空節點直接返回
if(val == t[pos].val) return pos;//等於就直接返回節點編號
else if(val < t[pos].val) return Find(ls, val);//在左子樹中查找
else return Find(rs, val);//在右子樹中查找
}
Insert
即插入操作, 需要插入權值為 \(val\) 的值。
其思想跟 \(Find\) 函數差不多,利用二叉搜索樹的性質直接就可以找到插入的位置。具體分為兩類:
- 有權值為 \(val\) 的點 \(pos\) ,直接使得副本數加 \(1\) 即可。
- 沒有權值為 \(val\) 的點 \(pos\) ,則開辟一個新的節點權值為 \(val\) 。
注意 \(pos\) 應傳實參,因為若開辟了新的節點,其父節點的對應兒子也需要改變。
void Insert(int &pos, int val, int fa) {//pos為實參
if (!pos) Splay(pos = New(val, fa), 0);
else if (val == t[pos].val) { ++t[pos].siz, ++t[pos].cnt; Splay(pos, 0); }
else if (val < t[pos].val) Insert(ls, val, pos);
else Insert(rs, val, pos);
}
Erase
即刪除操作, \(Erase(val)\) 定義為:刪除所維護的序列中權值為 \(val\) 的一個節點(如果有的話)。
可以先找到權值為 \(val\) 的節點並定義其編號為 \(pos\) ,分兩種情況。
- 若當前節點的副本數大於 \(1\) 時,即 \(t[pos].cnt>1\) 時,可以刪除其中一個副本即可,但並沒有刪除這個節點。
- 否則,則需要刪除該節點。需要先將 \(pos\) splay 到根節點。找到它的前綴的編號 \(l\) 和它的后綴的編號 \(r\) ,則 \(t[l].val\leq val \leq t[r].val\) 。顯然, \((t[l].val,t[r].val)\) 區間內的數只有一個,即 \(pos\) 。將 \(l\) splay 至根節點, \(r\) splay 至 \(l\) 的右兒子,則 \(pos\) 必會在 \(r\) 的左兒子處,因為 \(l\) , \(r\) , \(pos\) 必回滿足二叉搜索樹的性質。然后直接刪除 \(r\) 的左兒子即可。
void Erase(int val) {
int pos = Find(root, val);//找到權值為 val 的點。
if (!pos) return;//沒有改節點直接返回,沒有難倒刪空氣?
if (t[pos].cnt > 1) { --t[pos].siz, --t[pos].cnt; Splay(pos, 0); Update(pos); return; }//對應情況1
Splay(pos, 0);
int l = ls, r = rs;
while (t[l].son[1]) l = t[l].son[1];//找到前綴
while (t[r].son[0]) r = t[r].son[0];//找到后繼
Splay(l, 0); Splay(r, l);//對應情況2
t[r].son[0] = 0;
Update(r); Update(l);
}
這里在提供一種做法,與 \(Find\) 函數的做法類似,可以說是其的升級版。總體框架不變,主要是針對第二種情況,將其旋轉到根節點在進行刪除,這種寫法還是比較常見的。
void Erase(int pos, int val) {
if(!pos) return;
if(val == t[pos].val) {
if(t[pos].cnt > 1) { t[pos].siz--, t[pos].cnt--; Splay(pos, 0); return; }
if(ls) Rotate(ls), Erase(pos, val);//有左兒子跟左兒子交換
else if(rs) Rotate(rs), Erase(pos, val);//有右兒子就跟右兒子交換
else {//沒有兒子就直接刪除,注意必須刪除其父親的對應兒子
int newroot = t[pos].fa;
t[t[pos].fa].son[Ident(pos)] = 0;
Splay(newroot, 0);
}
return;
}
else if(val < t[pos].val) rase(ls, val);
else Erase(rs, val);
}
Query_kth
查詢 \(val\) 在序列是第幾大的樹,即按照從小到大的順序排序后, \(val\) 的排名,沒有 \(val\) 輸出返回 \(-1\)。
代碼使用遞歸實現,考慮對於當前節點 \(pos\) ,比 \(val\) 小的數都在左子樹內,即有 \(t[ls].siz\) 個樹比 \(t[pos].val\) 小。
對於局部解,可以將 \(Querykth(pos,val)\) 函數理解為 \(pos\) 的子樹中,小於 \(val\) 的值有多少。
則可以分為三種情況來討論。
- 當 \(val=t[pos].val\) 時,即找到了該節點,返回比它小的數的個數即可,即左子樹的節點數加 \(1\) 。
- 當 \(val<t[pos].val\) 時, \(val\) 左子樹中,在左子樹中查詢該節點的排名。
- 當 \(val>t[pos].val\) 時, 是最麻煩的部分。 \(val\) 右子樹中,左子樹與當前節點都會為答案做貢獻,先將其統計至答案中,在求出右子樹對於答案的貢獻。
注意,最后的答案是包含了極小值的,所以找到后的答案應該減一,這一部分我寫在了主函數里,所以沒找到會輸出 \(-1\) 。
int Query_kth(int pos, int val) {
if(!pos) return 0;//沒有輸出-1
if(val == t[pos].val) { int res = t[ls].siz + 1; Splay(pos, 0); return res; }//對應情況1
else if(val < t[pos].val) return Query_kth(ls, val);//對於情況2
//下兩行代碼對應情況3
int res = t[ls].siz + t[pos].cnt;//找到后splay維護形態會導致子樹的大小變化,因此先記錄答案
return Query_kth(rs, val) + res;
}
Query_val
查詢區間的第 \(k\) 小的數。
可以看做上一個操作的逆操作吧,若 \(k\) 都大於了區間的所有數的個數,就直接返回極大值。
同樣,對於局部解,可以將 \(Queryval(pos,k)\) 函數理解為 \(pos\) 的子樹中,第 \(k\) 大值為多少。
又可以分為三個情況:
- 當 \(t[ls].siz\geq k\) 時,即所求答案在左子樹,在左邊查詢即可。
- 當 \(t[ls].siz+t[pos].cnt\geq k\) 時, 答案為 \(t[pos].val\) ,因為第 \(t[ls].siz+1\) 小至 \(t[ls].siz+t[pos].cnt\) 的數全部權值都為 \(t[pos].val\) 。
- 否則,答案全部會在右子樹當中,查詢右子樹第 \(k-t[ls].siz-t[pos].cnt\) 大,因為當前節點與左兒子一定比右子樹任何一個數小。
同樣的需要注意,最后的答案是包含了極小值的,同樣這一部分我寫在了主函數里,查詢的時候需要查詢第 \(k+1\) 大的那個數。
int Query_val(int pos, int rank) {
if(!pos) return INF;
if(t[ls].siz >= rank) return Query_val(ls, rank);
else if(t[ls].siz + t[pos].cnt >= rank) { Splay(pos, 0); return t[pos].val; }
return Query_val(rs, rank - t[ls].siz - t[pos].cnt);
}
Get_Pre、Get_Nxt
在 \(Erase\) 操作中提到過,可以使用那樣的做法。
亦可使用在文末的代碼中稍快的做法,與 \(Find\) 函數相似,這里就不多說了。(其實是不想打字了)
也可以參照這段代碼將一些操作寫為非遞歸的寫法,會更快一些。
總結
有些細心的同學可能已經發現了,幾乎每個操作都有 splay 操作來維護當前樹的形態,保證時間復雜度。
C++代碼
只是將上述操作拼起來放在一個代碼里。
說明一下操作的幾種類型:
- 插入 \(x\) 數。
- 刪除 \(x\) 數(若有多個相同的數,因只刪除一個)。
- 查詢 \(x\) 數的排名(排名定義為比當前數小的數的個數 \(+1\) )。
- 查詢排名為 \(x\) 的數。
- 求 \(x\) 的前驅(前驅定義為小於 \(x\),且最大的數)。
- 求 \(x\) 的后繼(后繼定義為大於 \(x\),且最小的數)。
不是特別長,實現的方法也並不困難,打的時候必須得注意,完整沒附上注釋的代碼:
#include <cstdio>
#define INF 0x3f3f3f3f
#define Ident(pos) ( t[t[pos].fa].son[1] == pos )
const int MAXN = 1e5 + 5;
struct Splay_Tree {
int son[2], val, cnt, siz, fa;
#define ls t[pos].son[0]
#define rs t[pos].son[1]
};
int root, tot, q;
Splay_Tree t[MAXN];
int New(int val, int fa) {
t[++tot].fa = fa, t[tot].cnt = t[tot].siz = 1, t[tot].val = val;
return tot;
}
void Update(int pos) { t[pos].siz = t[ls].siz + t[rs].siz + t[pos].cnt; }
void Build() { root = New(-INF, 0); t[root].son[1] = New(INF, root); Update(root); }
void Connect(int pos, int fa, int flag) { t[fa].son[flag] = pos, t[pos].fa = fa; }
void Rotate(int pos) {
int fa = t[pos].fa, grand = t[fa].fa;
int flag1 = Ident(pos), flag2 = Ident(fa);
Connect(pos, grand, flag2);
Connect(t[pos].son[flag1 ^ 1], fa, flag1);
Connect(fa, pos, flag1 ^ 1);
Update(fa); Update(pos);
}
void Splay(int pos, int to) {
for (int fa = t[pos].fa; t[pos].fa != to; Rotate(pos), fa = t[pos].fa)
if (t[fa].fa != to) Ident(pos) == Ident(fa) ? Rotate(fa) : Rotate(pos);
if (!to) root = pos;
}
int Find(int pos, int val) {
if (!pos) return 0;
if (val == t[pos].val) return pos;
else if (val < t[pos].val) return Find(ls, val);
return Find(rs, val);
}
void Insert(int &pos, int val, int fa) {
if (!pos) Splay(pos = New(val, fa), 0);
else if (val == t[pos].val) { ++t[pos].siz, ++t[pos].cnt; Splay(pos, 0); }
else if (val < t[pos].val) Insert(ls, val, pos);
else Insert(rs, val, pos);
}
void Erase(int val) {
int pos = Find(root, val);
if (!pos) return;
if (t[pos].cnt > 1) { --t[pos].siz, --t[pos].cnt; Splay(pos, 0); Update(pos); return; }
Splay(pos, 0);
int l = ls, r = rs;
while (t[l].son[1]) l = t[l].son[1];
while (t[r].son[0]) r = t[r].son[0];
Splay(l, 0); Splay(r, l);
t[r].son[0] = 0;
Update(r); Update(l);
}
int Query_Rnk(int pos, int val) {
if (!ls && !rs && val != t[pos].val) { Splay(pos, 0); return 0; }
else if (val == t[pos].val) { int res = t[ls].siz + 1; Splay(pos, 0); return res; }
else if (val < t[pos].val) return Query_Rnk(ls, val);
int res = t[ls].siz + t[pos].cnt;
return Query_Rnk(rs, val) + res;
}
int Query_Kth(int pos, int rank) {
if (t[ls].siz >= rank && ls) return Query_Kth(ls, rank);
else if (t[ls].siz + t[pos].cnt >= rank) { Splay(pos, 0); return t[pos].val; }
else if (rs) return Query_Kth(rs, rank - t[ls].siz - t[pos].cnt);
Splay(pos, 0); return 0;
}
int Get_Pre(int val) {
int pos = root, res = root;
pos = root;
while (pos) {
if (t[pos].val < val) res = pos, pos = rs;
else pos = ls;
}
Splay(res, 0);
return t[res].val;
}
int Get_Nxt(int val) {
int pos = root, res = root;
while (pos) {
if (t[pos].val > val) res = pos, pos = ls;
else pos = rs;
}
Splay(res, 0);
return t[res].val;
}
int main() {
Build(); scanf("%d", &q);
for (int i = 1, opt, x; i <= q; i++) {
scanf("%d %d", &opt, &x);
if (opt == 1) Insert(root, x, 0);
else if (opt == 2) Erase(x);
else if (opt == 3) printf("%d\n", Query_Rnk(root, x) - 1);
else if (opt == 4) printf("%d\n", Query_Kth(root, x + 1));
else if (opt == 5) printf("%d\n", Get_Pre(x));
else printf("%d\n", Get_Nxt(x));
}
return 0;
}
補充
原本數據有億點水,錯的代碼都能過。
#include <cstdio>
#define INF 2147483647
#define Ident(pos) ( t[t[pos].fa].son[1] == pos )
const int MAXN = 2e6 + 5;
struct Splay_Tree {
int son[2], val, cnt, siz, fa;
#define ls t[pos].son[0]
#define rs t[pos].son[1]
};
int root, tot, ans, n, m;
Splay_Tree t[MAXN];
int New(int val, int fa) {
t[++tot].fa = fa, t[tot].cnt = t[tot].siz = 1, t[tot].val = val;
return tot;
}
void Update(int pos) { t[pos].siz = t[ls].siz + t[rs].siz + t[pos].cnt; }
void Build() { root = New(-INF, 0); t[root].son[1] = New(INF, root); Update(root); }
void Connect(int pos, int fa, int flag) { t[fa].son[flag] = pos, t[pos].fa = fa; }
void Rotate(int pos) {
int fa = t[pos].fa, grand = t[fa].fa;
int flag1 = Ident(pos), flag2 = Ident(fa);
Connect(pos, grand, flag2);
Connect(t[pos].son[flag1 ^ 1], fa, flag1);
Connect(fa, pos, flag1 ^ 1);
Update(fa); Update(pos);
}
void Splay(int pos, int to) {
for (int fa = t[pos].fa; t[pos].fa != to; Rotate(pos), fa = t[pos].fa)
if (t[fa].fa != to) Ident(pos) == Ident(fa) ? Rotate(fa) : Rotate(pos);
if (!to) root = pos;
}
int Find(int pos, int val) {
if (!pos) return 0;
if (val == t[pos].val) return pos;
else if (val < t[pos].val) return Find(ls, val);
return Find(rs, val);
}
void Insert(int &pos, int val, int fa) {
if (!pos) Splay(pos = New(val, fa), 0);
else if (val == t[pos].val) { ++t[pos].siz, ++t[pos].cnt; Splay(pos, 0); }
else if (val < t[pos].val) Insert(ls, val, pos);
else Insert(rs, val, pos);
}
void Erase(int val) {
int pos = Find(root, val);
if (!pos) return;
if (t[pos].cnt > 1) { --t[pos].siz, --t[pos].cnt; Splay(pos, 0); Update(pos); return; }
Splay(pos, 0);
int l = ls, r = rs;
while (t[l].son[1]) l = t[l].son[1];
while (t[r].son[0]) r = t[r].son[0];
Splay(l, 0); Splay(r, l);
t[r].son[0] = 0;
Update(r); Update(l);
}
int Query_Rnk(int pos, int val) {
if (!ls && !rs && val != t[pos].val) { Splay(pos, 0); return 0; }
else if (val == t[pos].val) { int res = t[ls].siz + 1; Splay(pos, 0); return res; }
else if (val < t[pos].val) return Query_Rnk(ls, val);
int res = t[ls].siz + t[pos].cnt;
return Query_Rnk(rs, val) + res;
}
int Query_Kth(int pos, int rank) {
if (t[ls].siz >= rank && ls) return Query_Kth(ls, rank);
else if (t[ls].siz + t[pos].cnt >= rank) { Splay(pos, 0); return t[pos].val; }
else if (rs) return Query_Kth(rs, rank - t[ls].siz - t[pos].cnt);
Splay(pos, 0); return 0;
}
int Get_Pre(int val) {
int pos = root, res = root;
pos = root;
while (pos) {
if (t[pos].val < val) res = pos, pos = rs;
else pos = ls;
}
Splay(res, 0);
return t[res].val;
}
int Get_Nxt(int val) {
int pos = root, res = root;
while (pos) {
if (t[pos].val > val) res = pos, pos = ls;
else pos = rs;
}
Splay(res, 0);
return t[res].val;
}
int main() {
Build(); scanf("%d %d", &n, &m);
for (int i = 1, a; i <= n; i++) scanf("%d", &a), Insert(root, a, 0);
for (int i = 1, opt, x, last = 0; i <= m; i++) {
scanf("%d %d", &opt, &x); x ^= last;
if (opt == 1) Insert(root, x, 0);
else if (opt == 2) Erase(x);
else if (opt == 3) {
Insert(root, x, 0);
last = Query_Rnk(root, x) - 1;
Erase(x);
}
else if (opt == 4) last = Query_Kth(root, x + 1);
else if (opt == 5) last = Get_Pre(x);
else last = Get_Nxt(x);
if (opt >= 3) ans ^= last;
}
printf("%d", ans);
return 0;
}