Treap 学习笔记
Treap 简介
Treap 是一种二叉查找树。它的结构同时满足二叉查找树(Tree)与堆(Heap)的性质,因此得名。Treap的原理是为每一个节点赋一个随机值使其满足堆的性质,保证了树高期望 O(log2n) ,从而保证了时间复杂度。
Treap 是一种高效的平衡树算法,在常数大小与代码复杂度上好于 Splay。
Treap 的基本操作
现在以 BZOJ 3224 普通平衡树为模板题,详细讨论 Treap 的基本操作。
1.基本结构
在一般情况下,Treap 的节点需要存储它的左右儿子,子树大小,节点中相同元素的数量(如果没有可以默认为1),自身信息及随机数的值。
struct node{ int l, r, v, siz, rnd, ct; }d[1000005];
其中 l
为左儿子节点编号, r
为右儿子节点编号, v
为节点数值, siz
为子树大小, rnd
为节点的随机值, ct
为该节点数值的出现次数(目的为将所有数值相同的点合为一个)。
2.关于随机值
随机值由 rand()
函数生成, 考虑到 <cstdlib>
库中的 rand()
速度较慢,所以在卡常数的时候建议手写 rand()
函数。
inline int rand(){ static int seed = 2333; return seed = (int)((((seed ^ 998244353) + 19260817ll) * 19890604ll) % 1000000007); }
其中 seed
为随机种子,可以随便填写。
3.节点信息更新
节点信息更新由 update()
函数实现。在每次产生节点关系的修改后,需要更新节点信息(最基本的子树大小,以及你要维护的其他内容)。
时间复杂度 O(1) 。
inline void update(int k){ d[k].siz = d[lc].siz + d[rc].siz + d[k].ct; }
4.「重要」左旋与右旋
左旋与右旋是 Treap 的核心操作,也是 Treap 动态保持树的深度的关键,其目的为维护 Treap 堆的性质。
下面的图片可以让你更好的理解左旋与右旋:
下面具体介绍左旋与右旋操作。左旋与右旋均为变更操作节点与其两个儿子的相对位置的操作。
「左旋」为将作儿子节点代替根节点的位置, 根节点相应的成为左儿子节点的右儿子(满足二叉搜索树的性质)。相应的,之前左儿子节点的右儿子应转移至之前根节点的左儿子。此时,只有之前的根节点与左儿子节点的 siz
发生了变化。所以要 update()
这两个节点。
「右旋」类似于「左旋」,将左右关系相反即可。
时间复杂度 O(1) 。
void rturn(int &k){ //右旋 int t = d[k].l; d[k].l = d[t].r; d[t].r = k; d[t].siz = d[k].siz; update(k); k = t; } void lturn(int &k){ //左旋 int t = d[k].r; d[k].r = d[t].l; d[t].l = k; d[t].siz = d[k].siz; update(k); k = t; }
5.节点的插入与删除
节点的插入与删除是 Treap 的基本功能之一。
「节点的插入」是一个递归的过程,我们从根节点开始,逐个判断当前节点的值与插入值的大小关系。如果插入值小于当前节点值,则递归至左儿子;大于则递归至右儿子;
相等则直接在把当前节点数值的出现次数 +1 ,跳出循环即可。如果当前访问到了一个空节点,则初始化新节点,将其加入到 Treap 的当前位置。
「节点的删除」同样是一个递归的过程,不过需要讨论多种情况:
如果插入值小于当前节点值,则递归至左儿子;大于则递归至右儿子。
如果插入值等于当前节点值:
若当前节点数值的出现次数大于 1 ,则减一;
若当前节点数值的出现次数等于于 1 :
若当前节点没有左儿子与右儿子,则直接删除该节点(置 0);
若当前节点没有左儿子或右儿子,则将左儿子或右儿子替代该节点;
若当前节点有左儿子与右儿子,则不断旋转 当前节点,并走到当前节点新的对应位置,直到没有左儿子或右儿子为止。
时间复杂度均为 O(log2n) 。
具体实现代码如下:
1 void ins(int &p,int x) 2 { 3 if (p==0) 4 { 5 p=++sz; 6 tr[p].siz=tr[p].ct=1,tr[p].val=x,tr[p].rnd=rand(); 7 return; 8 } 9 tr[p].siz++; 10 if (tr[p].val==x) tr[p].ct++; 11 else if (x>tr[p].val) 12 { 13 ins(tr[p].r,x); 14 if (tr[rs].rnd<tr[p].rnd) lturn(p); 15 }else 16 { 17 ins(tr[p].l,x); 18 if (tr[ls].rnd<tr[p].rnd) rturn(p); 19 } 20 } 21 void del(int &p,int x) 22 { 23 if (p==0) return; 24 if (tr[p].val==x) 25 { 26 if (tr[p].ct>1) tr[p].ct--,tr[p].siz--;//如果有多个直接减一即可。 27 else 28 { 29 if (ls==0||rs==0) p=ls+rs;//单节点或者空的话直接儿子移上来或者删去即可。 30 else if (tr[ls].rnd<tr[rs].rnd) rturn(p),del(p,x); 31 else lturn(p),del(p,x); 32 } 33 } 34 else if (x>tr[p].val) tr[p].siz--,del(rs,x); 35 else tr[p].siz--,del(ls,x); 36 }
6.查询数x的排名
查询数x的排名可以利用在二叉搜索树上的相同方法实现。
具体思路为根据递归找到当前节点,并记录小于这个节点的节点的数量(左子树) 。
时间复杂度 O(log2n) 。
代码实现如下:
1 int find_pm(int p,int x) 2 { 3 if (p==0) return 0; 4 if (tr[p].val==x) return tr[ls].siz+1; 5 if (x>tr[p].val) return tr[ls].siz+tr[p].ct+find_pm(rs,x); 6 else return find_pm(ls,x); 7 }
7.查询排名为x的数
查询排名为x的数可以利用在二叉搜索树上的相同方法实现。
具体思路为根据当前x来判断该数在左子树还是右子树 。
时间复杂度 O(log2n) 。
代码实现如下:
1 int find_hj(int p,int x) 2 { 3 if (p==0) return inf; 4 if (tr[p].val<=x) return find_hj(rs,x); 5 else return min(tr[p].val,find_hj(ls,x)); 6 }
8.查询数的前驱与后继
查询数的前驱与后继同样可以递归实现。查前驱即为递归当前数,走到小于等于x的节点,并记录其中最大的。后继同理。
时间复杂度 O(log2n) 。
代码实现如下:
1 int find_qq(int p,int x) 2 { 3 if (p==0) return -inf; 4 if (tr[p].val<x) return max(tr[p].val,find_qq(rs,x)); 5 else if (tr[p].val>=x) return find_qq(ls,x); 6 } 7 int find_hj(int p,int x) 8 { 9 if (p==0) return inf; 10 if (tr[p].val<=x) return find_hj(rs,x); 11 else return min(tr[p].val,find_hj(ls,x)); 12 }
总体合并的模板由bzoj3224这题,这是一道模板题,裸的。
1 #include<cstring> 2 #include<cmath> 3 #include<iostream> 4 #include<algorithm> 5 #include<cstdio> 6 7 #define ls tr[p].l 8 #define rs tr[p].r 9 #define N 100007 10 #define inf 2000000010 11 using namespace std; 12 inline int read() 13 { 14 int x=0,f=1;char ch=getchar(); 15 while(ch<'0'||ch>'9'){if (ch=='-')f=-1;ch=getchar();} 16 while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();} 17 return x*f; 18 } 19 20 int n,sz,rt,ans; 21 struct Node 22 { 23 int l,r,val,siz,rnd,ct;//记录左儿子,右儿子,点值,该子树大小,随机的值,该点值出现的次数。 24 }tr[N];//最多多少个节点,就开多少空间 25 26 inline int rand(){ 27 static int seed = 2333; 28 return seed = (int)((((seed ^ 998244353) + 19260817ll) * 19890604ll) % 1000000007); 29 } 30 inline void update(int p) 31 { 32 tr[p].siz=tr[ls].siz+tr[rs].siz+tr[p].ct; 33 } 34 void lturn(int &p) 35 { 36 int t=tr[p].r;tr[p].r=tr[t].l;tr[t].l=p; 37 tr[t].siz=tr[p].siz;update(p);p=t; 38 } 39 void rturn(int &p) 40 { 41 int t=tr[p].l;tr[p].l=tr[t].r;tr[t].r=p; 42 tr[t].siz=tr[p].siz;update(p);p=t; 43 } 44 void ins(int &p,int x) 45 { 46 if (p==0) 47 { 48 p=++sz; 49 tr[p].siz=tr[p].ct=1,tr[p].val=x,tr[p].rnd=rand(); 50 return; 51 } 52 tr[p].siz++; 53 if (tr[p].val==x) tr[p].ct++; 54 else if (x>tr[p].val) 55 { 56 ins(tr[p].r,x); 57 if (tr[rs].rnd<tr[p].rnd) lturn(p); 58 }else 59 { 60 ins(tr[p].l,x); 61 if (tr[ls].rnd<tr[p].rnd) rturn(p); 62 } 63 } 64 void del(int &p,int x) 65 { 66 if (p==0) return; 67 if (tr[p].val==x) 68 { 69 if (tr[p].ct>1) tr[p].ct--,tr[p].siz--;//如果有多个直接减一即可。 70 else 71 { 72 if (ls==0||rs==0) p=ls+rs;//单节点或者空的话直接儿子移上来或者删去即可。 73 else if (tr[ls].rnd<tr[rs].rnd) rturn(p),del(p,x); 74 else lturn(p),del(p,x); 75 } 76 } 77 else if (x>tr[p].val) tr[p].siz--,del(rs,x); 78 else tr[p].siz--,del(ls,x); 79 } 80 int find_pm(int p,int x) 81 { 82 if (p==0) return 0; 83 if (tr[p].val==x) return tr[ls].siz+1; 84 if (x>tr[p].val) return tr[ls].siz+tr[p].ct+find_pm(rs,x); 85 else return find_pm(ls,x); 86 } 87 int find_sz(int p,int x) 88 { 89 if (p==0) return 0; 90 if (x<=tr[ls].siz) return find_sz(ls,x); 91 x-=tr[ls].siz; 92 if (x<=tr[p].ct) return tr[p].val; 93 x-=tr[p].ct; 94 return find_sz(rs,x); 95 } 96 int find_qq(int p,int x) 97 { 98 if (p==0) return -inf; 99 if (tr[p].val<x) return max(tr[p].val,find_qq(rs,x)); 100 else if (tr[p].val>=x) return find_qq(ls,x); 101 } 102 int find_hj(int p,int x) 103 { 104 if (p==0) return inf; 105 if (tr[p].val<=x) return find_hj(rs,x); 106 else return min(tr[p].val,find_hj(ls,x)); 107 } 108 int main() 109 { 110 n=read(); 111 for (int i=1;i<=n;i++) 112 { 113 int flag=read(),x=read(); 114 if (flag==1) ins(rt,x); 115 if (flag==2) del(rt,x); 116 if (flag==3) printf("%d\n",find_pm(rt,x));//寻找x的排名 117 if (flag==4) printf("%d\n",find_sz(rt,x));//寻找排名为x的数字 118 if (flag==5) printf("%d\n",find_qq(rt,x)); 119 if (flag==6) printf("%d\n",find_hj(rt,x)); 120 } 121 }
特别感谢:http://blog.csdn.net/infinity_edge/article/details/78607724