引言
樹套樹,顧名思義,就是要將兩種或多種樹形數據結構結合起來,解決一些單獨無法解決的問題。
如果說要解決區間上的問題,如最大值,區間修改等,肯定會想到線段樹。
但是線段樹不能查詢第k大,不能查詢一個數在區間的排名,自然也不能查詢前驅和后繼。
平衡樹可以解決查詢排名、前驅、后繼等問題,但其不能限定區間。
文藝平衡樹中有操作可以把區間鎖定在一個結點的子樹,問題是只能通過翻轉左右子樹,來實現區間翻轉。
既然單獨無法解決這個問題,那就將兩種樹形數據結構結合起來。
原理
很多人都對樹套樹望而生畏,包括我。。。
以前只知道通過樹套樹可以解決的問題,但沒有敲過
經常聽到隊友說幾個線段樹,再用一個主席樹維護什么什么的,但其實原理不難,只要懂樹套樹中的這兩種樹。
舉個例子:

如圖這是個線段樹
假設這個序列是: 5 2 3 4 5 7 8 9 3 1 (隨便寫的)
現在我要查2-7區間中第5個數即5在這個區間排第幾小,
[3-7]區間即:[3],[4-5],[6-7]
第幾小即計算有多少個比它小,然后加一
[3]:1個
[4-5]:1個
[6-7]:0個
所以他是第3小的。
將每個子區間得到的答案求和利用的是線段樹
而中間每個區間查詢有多少比它小利用的是平衡樹Splay
即線段樹的每個結點建立一個Splay
有人會懷疑空間復雜度不夠,如果把Splay封裝,每個Splay都是\(N\)的大小必然不夠,我們不需要事先開辟那么多空間來建Splay
代碼中是:(開局一個root,然后記錄每個線段樹結點的root就行了)
void build(int p,int l,int r){
t[p].l = l,t[p].r = r;
//線段樹每個結點建立一個splay
sp.ins(t[p].rt,-inf);
sp.ins(t[p].rt,inf);
for(int i = l;i <= r;++i){
sp.ins(t[p].rt,arr[i]);
}
if(l == r){ t[p].mx = t[p].mn = arr[l];return; }
int mid = (l+r) >> 1;
build(p<<1,l,mid);
build(p<<1|1,mid + 1,r);
pushUp(p);
}
這里有個問題,root會發生變化,所以線段樹結點中定義的root並不是一成不變的,這需要用到引用,即傳地址。
還有就是要插入兩個無窮大結點,來解決不存在的情況。
應用-模板題
- 查詢k在區間內的排名
- 查詢區間內排名為k的值
- 修改某一位值上的數值
- 查詢k在區間內的前驅(前驅定義為嚴格小於x,且最大的數,若不存在輸出-2147483647)
- 查詢k在區間內的后繼(后繼定義為嚴格大於x,且最小的數,若不存在輸出2147483647)
先復制上封裝好的Splay
struct Splay {
int get(int x) {return s[s[x].fa].ch[1] == x;}
void Clear(int x) {
s[x].fa = s[x].ch[0] = s[x].ch[1] = s[x].sz = s[x].val =0;
}
void maintain(int x){
s[x].sz = s[s[x].ch[0]].sz + s[s[x].ch[1]].sz + s[x].cnt;
}
void Rorate(int x){
int y = s[x].fa, z = s[y].fa, chk = get(x);
s[y].ch[chk] = s[x].ch[chk ^ 1];
s[s[x].ch[chk ^ 1]].fa = y;
s[y].fa = x;
s[x].ch[chk ^ 1] = y;
s[x].fa =z;
if(z) s[z].ch[s[z].ch[1] == y] = x;
maintain(y);
maintain(x);
}
void splay(int &rt,int x,int y){
for(int f = s[x].fa;f != y;Rorate(x),f=s[x].fa){
if(s[f].fa != y) Rorate(get(x) == get(f) ? f : x);
}
if(y==0) rt = x;
}
void ins(int &root ,int val){
if(!root) {
root = ++tot;
s[root].val = val;
s[root].cnt++;
maintain(root);
return ;
}
int f = 0, x = root;
while(true){
if(s[x].val == val){
s[x].cnt ++;
maintain(x);
maintain(f);
splay(root,x,0);
return;
}
f = x;
x = s[x].ch[s[x].val < val];
if(!x) {
s[++tot].val = val;
s[tot].cnt = 1;
s[tot].fa = f;
s[f].ch[s[f].val < val] = tot;
maintain(tot);
maintain(f);
splay(root,tot,0);
return ;
}
}
}
inline int Find(int &rt,int k) {
int res = 0,now = rt;
while(true) {
if(k<s[now].val) {
now = s[now].ch[0];
}else {
//否則加上右子樹的個數
res += s[s[now].ch[0]].sz;
//中序遍歷,如果找到這個節點返回res+1
if(k == s[now].val) {
splay(rt,now,0);
return res + 1;
}
res += s[now].cnt;
now = s[now].ch[1];
}
}
}
int getPre(int rt){
int now = s[rt].ch[0];
while (s[now].ch[1]) now = s[now].ch[1];
return now;
}
int getNxt(int rt){
int now = s[rt].ch[1];
while (s[now].ch[0]) now = s[now].ch[0];
return now;
}
inline void del(int &rt,int k){
Find(rt,k);//先讓該點成為根節點
if(s[rt].cnt > 1) {//如果大於1,不需要刪除節點
s[rt].cnt--;
maintain(rt);
return;
}
//如果只有一個點
if(!s[rt].ch[0] && !s[rt].ch[1]){
Clear(rt);
rt = 0;
return;
}
//沒有左兒子,讓右兒子成為根節點
if(!s[rt].ch[0]){
int tmp = rt;
rt = s[rt].ch[1];
s[rt].fa=0;
Clear(tmp);
return;
}
//沒有右兒子,讓左兒子成為根節點
if(!s[rt].ch[1]){
int tmp = rt;
rt = s[rt].ch[0];
s[rt].fa = 0;
Clear(tmp);
return;
}
//有左右兒子,讓前驅成為根節點
int x = getPre(rt) , now = rt;
splay(rt,x,0);
s[s[now].ch[1]].fa = x;
s[x].ch[1] = s[now].ch[1];
Clear(now);
maintain(rt);
}
}sp;
- 問題1之前提到了,就是在Splay中插入這個結點,然后返回這個結點的左兒子的Size就行,記得減去無窮大的那個點。
int query_order(int p,int l,int r,int val){
//查詢順序,就是查有多少個比他小
if(l <= t[p].l && t[p].r <= r){
sp.ins(t[p].rt,val);
int res = s[s[t[p].rt].ch[0]].sz-1;
sp.del(t[p].rt,val);
return res;
}
int mid = (t[p].l + t[p].r) >> 1,res = 0;
if(l <= mid) res += query_order(p << 1,l,r,val);
if(mid < r) res += query_order(p << 1|1,l,r,val);
return res;
}
- 問題2求排名k的值,這需要用到二分,二分check函數就是問題1的
query_order,在區間權值范圍內二分,權值越大排名越大,就是在單調遞增區間中查詢小於k的數的最大值(因為有一個無窮小結點,所以不能小於等於)。二分模板也很明顯:
int query_number(int L,int R,int val){
int l = 1,r = getMax(1,L,R) ,mid,tmp;
while(l < r){
mid = (l + r + 1)>>1;
tmp = query_order(1,L,R,mid);
if(tmp < val){
l = mid;
}else{
r = mid - 1;
}
}
return l;
}
- 問題3是修改,這個不難,這個點所在的所有線段樹結點都要刪除該點在Splay樹上的結點,然后加入新值。
void modify(int p,int pos,int val){
sp.del(t[p].rt,arr[pos]);
sp.ins(t[p].rt,val);
if(t[p].l == t[p].r){
t[p].mx = val;
t[p].mn = val;
arr[pos] = val;
return;
}
int mid = (t[p].l + t[p].r) >> 1;
if(pos <= mid) modify(p << 1,pos,val);
if(pos > mid) modify(p << 1 | 1,pos,val);
pushUp(p);
}
- 問題4,查詢前驅,即查詢每個線段樹區間最大的比該數小的數,最后取個最大值。5同理
int query_Pre(int p,int l,int r,int val){
if(l <= t[p].l && r >= t[p].r){
sp.ins(t[p].rt,val);
int res = s[sp.getPre(t[p].rt)].val;
sp.del(t[p].rt,val);
return res;
}
int res = -inf,mid = (t[p].l + t[p].r) >> 1;
if(l <= mid) res = max(res,query_Pre(p << 1,l,r,val));
if(r > mid) res = max(res,query_Pre(p << 1|1,l,r,val));
return res;
}
int query_Nxt(int p,int l,int r,int val){
if(l <= t[p].l && r >= t[p].r){
sp.ins(t[p].rt,val);
int res = s[sp.getNxt(t[p].rt)].val;
sp.del(t[p].rt,val);
return res;
}
int res = inf,mid = (t[p].l + t[p].r) >> 1;
if(l <= mid) res = min(res,query_Nxt(p << 1,l,r,val));
if(r > mid) res = min(res,query_Nxt(p << 1|1,l,r,val));
return res;
}
- 中途為了優化二分(也沒什么用),還加了線段樹查詢最大值和最小值的
int getMax(int p,int l,int r){
if(l <= t[p].l && t[p].r <=r) return t[p].mx;
int mid = (t[p].l + t[p].r) >> 1,res = -inf;
if(l <= mid) res = max(res,getMax(p << 1,l,r));
if(mid < r) res = max(res,getMax(p << 1 | 1,l,r));
return res;
}
int getMin(int p,int l,int r){
if(l <= t[p].l && t[p].r <= r) return t[p].mx;
int mid = (t[p].l + t[p].r) >> 1 ,res = inf;
if(l <= mid) res = min(res,getMin(p << 1,l,r));
if(mid < r) res = min(res,getMin(p << 1|1,l,r));
return res;
}
完整代碼
#pragma GCC optimize(2)
#pragma GCC optimize(3,"Ofast","inline")
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 1e7+7;
const int inf = 2147483647;
int tot;//節點個數
struct node {
int fa;//父親節點
int ch[2];//子節點
int val;//權值
int sz;//子樹大小
int cnt;
}s[N];
struct Tree{
int rt,l,r,mx,mn;
}t[N];
int arr[N];
struct Splay {
int get(int x) {return s[s[x].fa].ch[1] == x;}
void Clear(int x) {
s[x].fa = s[x].ch[0] = s[x].ch[1] = s[x].sz = s[x].val =0;
}
void maintain(int x){
s[x].sz = s[s[x].ch[0]].sz + s[s[x].ch[1]].sz + s[x].cnt;
}
void Rorate(int x){
int y = s[x].fa, z = s[y].fa, chk = get(x);
s[y].ch[chk] = s[x].ch[chk ^ 1];
s[s[x].ch[chk ^ 1]].fa = y;
s[y].fa = x;
s[x].ch[chk ^ 1] = y;
s[x].fa =z;
if(z) s[z].ch[s[z].ch[1] == y] = x;
maintain(y);
maintain(x);
}
void splay(int &rt,int x,int y){
for(int f = s[x].fa;f != y;Rorate(x),f=s[x].fa){
if(s[f].fa != y) Rorate(get(x) == get(f) ? f : x);
}
if(y==0) rt = x;
}
void ins(int &root ,int val){
if(!root) {
root = ++tot;
s[root].val = val;
s[root].cnt++;
maintain(root);
return ;
}
int f = 0, x = root;
while(true){
if(s[x].val == val){
s[x].cnt ++;
maintain(x);
maintain(f);
splay(root,x,0);
return;
}
f = x;
x = s[x].ch[s[x].val < val];
if(!x) {
s[++tot].val = val;
s[tot].cnt = 1;
s[tot].fa = f;
s[f].ch[s[f].val < val] = tot;
maintain(tot);
maintain(f);
splay(root,tot,0);
return ;
}
}
}
inline int Find(int &rt,int k) {
int res = 0,now = rt;
while(true) {
if(k<s[now].val) {
now = s[now].ch[0];
}else {
//否則加上右子樹的個數
res += s[s[now].ch[0]].sz;
//中序遍歷,如果找到這個節點返回res+1
if(k == s[now].val) {
splay(rt,now,0);
return res + 1;
}
res += s[now].cnt;
now = s[now].ch[1];
}
}
}
int getPre(int rt){
int now = s[rt].ch[0];
while (s[now].ch[1]) now = s[now].ch[1];
return now;
}
int getNxt(int rt){
int now = s[rt].ch[1];
while (s[now].ch[0]) now = s[now].ch[0];
return now;
}
inline void del(int &rt,int k){
Find(rt,k);//先讓該點成為根節點
if(s[rt].cnt > 1) {//如果大於1,不需要刪除節點
s[rt].cnt--;
maintain(rt);
return;
}
//如果只有一個點
if(!s[rt].ch[0] && !s[rt].ch[1]){
Clear(rt);
rt = 0;
return;
}
//沒有左兒子,讓右兒子成為根節點
if(!s[rt].ch[0]){
int tmp = rt;
rt = s[rt].ch[1];
s[rt].fa=0;
Clear(tmp);
return;
}
//沒有右兒子,讓左兒子成為根節點
if(!s[rt].ch[1]){
int tmp = rt;
rt = s[rt].ch[0];
s[rt].fa = 0;
Clear(tmp);
return;
}
//有左右兒子,讓前驅成為根節點
int x = getPre(rt) , now = rt;
splay(rt,x,0);
s[s[now].ch[1]].fa = x;
s[x].ch[1] = s[now].ch[1];
Clear(now);
maintain(rt);
}
}sp;
void pushUp(int x){
t[x].mx = max(t[x<<1].mx,t[x<<1|1].mx);
t[x].mn = min(t[x<<1].mn,t[x<<1|1].mn);
}
void build(int p,int l,int r){
t[p].l = l,t[p].r = r;
//線段樹每個結點建立一個splay
sp.ins(t[p].rt,-inf);
sp.ins(t[p].rt,inf);
for(int i = l;i <= r;++i){
sp.ins(t[p].rt,arr[i]);
}
if(l == r){ t[p].mx = t[p].mn = arr[l];return; }
int mid = (l+r) >> 1;
build(p<<1,l,mid);
build(p<<1|1,mid + 1,r);
pushUp(p);
}
int getMax(int p,int l,int r){
if(l <= t[p].l && t[p].r <=r) return t[p].mx;
int mid = (t[p].l + t[p].r) >> 1,res = -inf;
if(l <= mid) res = max(res,getMax(p << 1,l,r));
if(mid < r) res = max(res,getMax(p << 1 | 1,l,r));
return res;
}
int getMin(int p,int l,int r){
if(l <= t[p].l && t[p].r <= r) return t[p].mx;
int mid = (t[p].l + t[p].r) >> 1 ,res = inf;
if(l <= mid) res = min(res,getMin(p << 1,l,r));
if(mid < r) res = min(res,getMin(p << 1|1,l,r));
return res;
}
int query_order(int p,int l,int r,int val){
//查詢順序,就是查有多少個比他小
if(l <= t[p].l && t[p].r <= r){
sp.ins(t[p].rt,val);
int res = s[s[t[p].rt].ch[0]].sz-1;
sp.del(t[p].rt,val);
return res;
}
int mid = (t[p].l + t[p].r) >> 1,res = 0;
if(l <= mid) res += query_order(p << 1,l,r,val);
if(mid < r) res += query_order(p << 1|1,l,r,val);
return res;
}
void modify(int p,int pos,int val){
sp.del(t[p].rt,arr[pos]);
sp.ins(t[p].rt,val);
if(t[p].l == t[p].r){
t[p].mx = val;
t[p].mn = val;
arr[pos] = val;
return;
}
int mid = (t[p].l + t[p].r) >> 1;
if(pos <= mid) modify(p << 1,pos,val);
if(pos > mid) modify(p << 1 | 1,pos,val);
pushUp(p);
}
int query_Pre(int p,int l,int r,int val){
if(l <= t[p].l && r >= t[p].r){
sp.ins(t[p].rt,val);
int res = s[sp.getPre(t[p].rt)].val;
sp.del(t[p].rt,val);
return res;
}
int res = -inf,mid = (t[p].l + t[p].r) >> 1;
if(l <= mid) res = max(res,query_Pre(p << 1,l,r,val));
if(r > mid) res = max(res,query_Pre(p << 1|1,l,r,val));
return res;
}
int query_Nxt(int p,int l,int r,int val){
if(l <= t[p].l && r >= t[p].r){
sp.ins(t[p].rt,val);
int res = s[sp.getNxt(t[p].rt)].val;
sp.del(t[p].rt,val);
return res;
}
int res = inf,mid = (t[p].l + t[p].r) >> 1;
if(l <= mid) res = min(res,query_Nxt(p << 1,l,r,val));
if(r > mid) res = min(res,query_Nxt(p << 1|1,l,r,val));
return res;
}
int query_number(int L,int R,int val){
int l = 1,r = getMax(1,L,R) ,mid,tmp;
while(l < r){
mid = (l + r + 1)>>1;
tmp = query_order(1,L,R,mid);
if(tmp < val){
l = mid;
}else{
r = mid - 1;
}
}
return l;
}
int main(){
int n,q,op,l,r,pos;
scanf("%d%d",&n,&q);
for(int i=1;i<=n;++i) scanf("%d",&arr[i]);
build(1,1,n);
while(q--){
scanf("%d",&op);
if(op == 1){
scanf("%d%d%d",&l,&r,&pos);
printf("%d\n",query_order(1,l,r,pos)+1);
}else if(op == 2){
scanf("%d%d%d",&l,&r,&pos);
printf("%d\n",query_number(l,r,pos));
}else if(op == 3){
scanf("%d%d",&l,&pos);
modify(1,l,pos);
}else if(op == 4){
scanf("%d%d%d",&l,&r,&pos);
printf("%d\n",query_Pre(1,l,r,pos));
}else if(op == 5){
scanf("%d%d%d",&l,&r,&pos);
printf("%d\n",query_Nxt(1,l,r,pos));
}
}
return 0;
}
代碼不加O2優化會超時,如果要優化的話,可以加個輸入輸出掛。
后記
博客兩周年快樂。
這是第一篇博客https://www.cnblogs.com/smallocean/p/8525932.html:2018.3.7
發現自己留下的東西都可以當作時間膠囊,等未來某天翻看的時候,仿佛能看到那個時候的自己。
