[學習筆記]樹套樹


引言

樹套樹,顧名思義,就是要將兩種或多種樹形數據結構結合起來,解決一些單獨無法解決的問題。

如果說要解決區間上的問題,如最大值,區間修改等,肯定會想到線段樹

但是線段樹不能查詢第k大,不能查詢一個數在區間的排名,自然也不能查詢前驅和后繼。

平衡樹可以解決查詢排名、前驅、后繼等問題,但其不能限定區間。

文藝平衡樹中有操作可以把區間鎖定在一個結點的子樹,問題是只能通過翻轉左右子樹,來實現區間翻轉。

既然單獨無法解決這個問題,那就將兩種樹形數據結構結合起來。

原理

很多人都對樹套樹望而生畏,包括我。。。

以前只知道通過樹套樹可以解決的問題,但沒有敲過

經常聽到隊友說幾個線段樹,再用一個主席樹維護什么什么的,但其實原理不難,只要懂樹套樹中的這兩種樹。

舉個例子:

如圖這是個線段樹

假設這個序列是: 5 2 3 4 5 7 8 9 3 1 (隨便寫的)

現在我要查2-7區間中第5個數即5在這個區間排第幾小,

[3-7]區間即:[3],[4-5],[6-7]

第幾小即計算有多少個比它小,然后加一

[3]:1個

[4-5]:1個

[6-7]:0個

所以他是第3小的。

將每個子區間得到的答案求和利用的是線段樹

而中間每個區間查詢有多少比它小利用的是平衡樹Splay

線段樹的每個結點建立一個Splay

有人會懷疑空間復雜度不夠,如果把Splay封裝,每個Splay都是\(N\)的大小必然不夠,我們不需要事先開辟那么多空間來建Splay

代碼中是:(開局一個root,然后記錄每個線段樹結點的root就行了)

void build(int p,int l,int r){
    t[p].l = l,t[p].r = r;
    //線段樹每個結點建立一個splay
    sp.ins(t[p].rt,-inf);
    sp.ins(t[p].rt,inf);
    for(int i = l;i <= r;++i){
        sp.ins(t[p].rt,arr[i]);
    }
    if(l == r){ t[p].mx = t[p].mn = arr[l];return; }
    int mid = (l+r) >> 1;
    build(p<<1,l,mid);
    build(p<<1|1,mid + 1,r);
    pushUp(p);
}

這里有個問題,root會發生變化,所以線段樹結點中定義的root並不是一成不變的,這需要用到引用,即傳地址

還有就是要插入兩個無窮大結點,來解決不存在的情況。

應用-模板題

  1. 查詢k在區間內的排名
  2. 查詢區間內排名為k的值
  3. 修改某一位值上的數值
  4. 查詢k在區間內的前驅(前驅定義為嚴格小於x,且最大的數,若不存在輸出-2147483647)
  5. 查詢k在區間內的后繼(后繼定義為嚴格大於x,且最小的數,若不存在輸出2147483647)

先復制上封裝好的Splay

struct Splay {
     int get(int x) {return s[s[x].fa].ch[1] == x;}

     void Clear(int x) {
         s[x].fa = s[x].ch[0] = s[x].ch[1] = s[x].sz = s[x].val =0;
     }

     void maintain(int x){
         s[x].sz = s[s[x].ch[0]].sz + s[s[x].ch[1]].sz + s[x].cnt;
     }

     void Rorate(int x){
         int y = s[x].fa, z = s[y].fa, chk = get(x);

         s[y].ch[chk] = s[x].ch[chk ^ 1];
         s[s[x].ch[chk ^ 1]].fa = y;

         s[y].fa = x;
         s[x].ch[chk ^ 1] = y;

         s[x].fa =z;
         if(z) s[z].ch[s[z].ch[1] == y] = x;

         maintain(y);
         maintain(x);
     }

     void splay(int &rt,int x,int y){
         for(int f = s[x].fa;f != y;Rorate(x),f=s[x].fa){
             if(s[f].fa != y) Rorate(get(x) == get(f) ? f : x);
         }
         if(y==0) rt = x;
     }

     void ins(int &root ,int val){
         if(!root) {
             root = ++tot;
             s[root].val = val;
             s[root].cnt++;
             maintain(root);
             return ;
         }

         int f = 0, x = root;
         while(true){
             if(s[x].val == val){
                 s[x].cnt ++;
                 maintain(x);
                 maintain(f);
                 splay(root,x,0);
                 return;
             }

             f = x;
             x = s[x].ch[s[x].val < val];
             if(!x) {
                 s[++tot].val = val;
                 s[tot].cnt = 1;
                 s[tot].fa = f;
                 s[f].ch[s[f].val < val] = tot;
                 maintain(tot);
                 maintain(f);
                 splay(root,tot,0);
                 return ;
             }
         }
     }

    inline int Find(int &rt,int k) {
        int res = 0,now = rt;
        while(true) {
            if(k<s[now].val) {
                now = s[now].ch[0];
            }else {
                //否則加上右子樹的個數
                res += s[s[now].ch[0]].sz;
                //中序遍歷,如果找到這個節點返回res+1
                if(k == s[now].val) {
                    splay(rt,now,0);
                    return res + 1;
                }
                res += s[now].cnt;
                now = s[now].ch[1];
            }
        }
    }

     int getPre(int rt){
         int now = s[rt].ch[0];
         while (s[now].ch[1]) now = s[now].ch[1];
         return now;
     }

     int getNxt(int rt){
         int now = s[rt].ch[1];
         while (s[now].ch[0]) now = s[now].ch[0];
         return now;
     }
     
    inline void del(int &rt,int k){
       Find(rt,k);//先讓該點成為根節點
        if(s[rt].cnt > 1) {//如果大於1,不需要刪除節點
            s[rt].cnt--;
            maintain(rt);
            return;
        }
        //如果只有一個點
        if(!s[rt].ch[0] && !s[rt].ch[1]){
            Clear(rt);
            rt = 0;
            return;
        }
        //沒有左兒子,讓右兒子成為根節點
        if(!s[rt].ch[0]){
            int tmp = rt;
            rt = s[rt].ch[1];
            s[rt].fa=0;
            Clear(tmp);
            return;
        }
        //沒有右兒子,讓左兒子成為根節點
        if(!s[rt].ch[1]){
            int tmp = rt;
            rt = s[rt].ch[0];
            s[rt].fa = 0;
            Clear(tmp);
            return;
        }
        //有左右兒子,讓前驅成為根節點
        int x = getPre(rt) , now = rt;
        splay(rt,x,0);
        s[s[now].ch[1]].fa = x;
        s[x].ch[1] = s[now].ch[1];
        Clear(now);
        maintain(rt);
    }
}sp;
  • 問題1之前提到了,就是在Splay中插入這個結點,然后返回這個結點的左兒子的Size就行,記得減去無窮大的那個點。
int query_order(int p,int l,int r,int val){
    //查詢順序,就是查有多少個比他小
    if(l <= t[p].l && t[p].r <= r){
       sp.ins(t[p].rt,val);
       int res = s[s[t[p].rt].ch[0]].sz-1;
       sp.del(t[p].rt,val);
       return res;
    }
    int mid = (t[p].l + t[p].r) >> 1,res = 0;
    if(l <= mid) res += query_order(p << 1,l,r,val);
    if(mid < r) res += query_order(p << 1|1,l,r,val);
    return res;
}
  • 問題2求排名k的值,這需要用到二分,二分check函數就是問題1的query_order,在區間權值范圍內二分,權值越大排名越大,就是在單調遞增區間中查詢小於k的數的最大值(因為有一個無窮小結點,所以不能小於等於)。二分模板也很明顯:
int query_number(int L,int R,int val){
    int l = 1,r = getMax(1,L,R) ,mid,tmp;
    while(l < r){
        mid = (l + r + 1)>>1;
        tmp = query_order(1,L,R,mid);
        if(tmp < val){
            l = mid;
        }else{
            r = mid - 1;
        }
    }
    return l;
}
  • 問題3是修改,這個不難,這個點所在的所有線段樹結點都要刪除該點在Splay樹上的結點,然后加入新值。
void modify(int p,int pos,int val){
     sp.del(t[p].rt,arr[pos]);
     sp.ins(t[p].rt,val);
     if(t[p].l == t[p].r){
         t[p].mx = val;
         t[p].mn = val;
         arr[pos] = val;
         return;
     }
     int mid = (t[p].l + t[p].r) >> 1;
     if(pos <= mid) modify(p << 1,pos,val);
     if(pos > mid) modify(p << 1 | 1,pos,val);
     pushUp(p);
}
  • 問題4,查詢前驅,即查詢每個線段樹區間最大的比該數小的數,最后取個最大值。5同理
int query_Pre(int p,int l,int r,int val){
    if(l <= t[p].l && r >= t[p].r){
        sp.ins(t[p].rt,val);
        int res = s[sp.getPre(t[p].rt)].val;
        sp.del(t[p].rt,val);
        return res;
    }
    int res = -inf,mid = (t[p].l + t[p].r) >> 1;
    if(l <= mid)  res = max(res,query_Pre(p << 1,l,r,val));
    if(r > mid)  res = max(res,query_Pre(p << 1|1,l,r,val));
    return res;
}

int query_Nxt(int p,int l,int r,int val){
    if(l <= t[p].l && r >= t[p].r){
        sp.ins(t[p].rt,val);
        int res = s[sp.getNxt(t[p].rt)].val;
        sp.del(t[p].rt,val);
        return res;
    }
    int res = inf,mid = (t[p].l + t[p].r) >> 1;
    if(l <= mid)  res = min(res,query_Nxt(p << 1,l,r,val));
    if(r > mid)  res = min(res,query_Nxt(p << 1|1,l,r,val));
    return res;
}
  • 中途為了優化二分(也沒什么用),還加了線段樹查詢最大值和最小值的
int getMax(int p,int l,int r){
    if(l <= t[p].l && t[p].r <=r) return t[p].mx;
    int mid = (t[p].l + t[p].r) >> 1,res = -inf;
    if(l <= mid) res = max(res,getMax(p << 1,l,r));
    if(mid < r) res = max(res,getMax(p << 1 | 1,l,r));
    return res;
}

int getMin(int p,int l,int r){
    if(l <= t[p].l && t[p].r <= r) return t[p].mx;
    int mid = (t[p].l + t[p].r) >> 1 ,res = inf;
    if(l <= mid) res = min(res,getMin(p << 1,l,r));
    if(mid < r) res = min(res,getMin(p << 1|1,l,r));
    return res;
}

完整代碼

#pragma GCC optimize(2)
#pragma GCC optimize(3,"Ofast","inline")
#include<bits/stdc++.h>
using namespace std;
#define ll long long

const int N = 1e7+7;
const int inf = 2147483647;
int tot;//節點個數
struct node {
    int fa;//父親節點
    int ch[2];//子節點
    int val;//權值
    int sz;//子樹大小
    int cnt;
}s[N];
struct Tree{
    int rt,l,r,mx,mn;
}t[N];
int arr[N];
struct Splay {
     int get(int x) {return s[s[x].fa].ch[1] == x;}

     void Clear(int x) {
         s[x].fa = s[x].ch[0] = s[x].ch[1] = s[x].sz = s[x].val =0;
     }

     void maintain(int x){
         s[x].sz = s[s[x].ch[0]].sz + s[s[x].ch[1]].sz + s[x].cnt;
     }

     void Rorate(int x){
         int y = s[x].fa, z = s[y].fa, chk = get(x);

         s[y].ch[chk] = s[x].ch[chk ^ 1];
         s[s[x].ch[chk ^ 1]].fa = y;

         s[y].fa = x;
         s[x].ch[chk ^ 1] = y;

         s[x].fa =z;
         if(z) s[z].ch[s[z].ch[1] == y] = x;

         maintain(y);
         maintain(x);
     }

     void splay(int &rt,int x,int y){
         for(int f = s[x].fa;f != y;Rorate(x),f=s[x].fa){
             if(s[f].fa != y) Rorate(get(x) == get(f) ? f : x);
         }
         if(y==0) rt = x;
     }

     void ins(int &root ,int val){
         if(!root) {
             root = ++tot;
             s[root].val = val;
             s[root].cnt++;
             maintain(root);
             return ;
         }

         int f = 0, x = root;
         while(true){
             if(s[x].val == val){
                 s[x].cnt ++;
                 maintain(x);
                 maintain(f);
                 splay(root,x,0);
                 return;
             }

             f = x;
             x = s[x].ch[s[x].val < val];
             if(!x) {
                 s[++tot].val = val;
                 s[tot].cnt = 1;
                 s[tot].fa = f;
                 s[f].ch[s[f].val < val] = tot;
                 maintain(tot);
                 maintain(f);
                 splay(root,tot,0);
                 return ;
             }
         }
     }

    inline int Find(int &rt,int k) {
        int res = 0,now = rt;
        while(true) {
            if(k<s[now].val) {
                now = s[now].ch[0];
            }else {
                //否則加上右子樹的個數
                res += s[s[now].ch[0]].sz;
                //中序遍歷,如果找到這個節點返回res+1
                if(k == s[now].val) {
                    splay(rt,now,0);
                    return res + 1;
                }
                res += s[now].cnt;
                now = s[now].ch[1];
            }
        }
    }

     int getPre(int rt){
         int now = s[rt].ch[0];
         while (s[now].ch[1]) now = s[now].ch[1];
         return now;
     }

     int getNxt(int rt){
         int now = s[rt].ch[1];
         while (s[now].ch[0]) now = s[now].ch[0];
         return now;
     }

    inline void del(int &rt,int k){
       Find(rt,k);//先讓該點成為根節點
        if(s[rt].cnt > 1) {//如果大於1,不需要刪除節點
            s[rt].cnt--;
            maintain(rt);
            return;
        }
        //如果只有一個點
        if(!s[rt].ch[0] && !s[rt].ch[1]){
            Clear(rt);
            rt = 0;
            return;
        }
        //沒有左兒子,讓右兒子成為根節點
        if(!s[rt].ch[0]){
            int tmp = rt;
            rt = s[rt].ch[1];
            s[rt].fa=0;
            Clear(tmp);
            return;
        }
        //沒有右兒子,讓左兒子成為根節點
        if(!s[rt].ch[1]){
            int tmp = rt;
            rt = s[rt].ch[0];
            s[rt].fa = 0;
            Clear(tmp);
            return;
        }
        //有左右兒子,讓前驅成為根節點
        int x = getPre(rt) , now = rt;
        splay(rt,x,0);
        s[s[now].ch[1]].fa = x;
        s[x].ch[1] = s[now].ch[1];
        Clear(now);
        maintain(rt);
    }
}sp;

void pushUp(int x){
    t[x].mx = max(t[x<<1].mx,t[x<<1|1].mx);
    t[x].mn = min(t[x<<1].mn,t[x<<1|1].mn);
}

void build(int p,int l,int r){
    t[p].l = l,t[p].r = r;
    //線段樹每個結點建立一個splay
    sp.ins(t[p].rt,-inf);
    sp.ins(t[p].rt,inf);
    for(int i = l;i <= r;++i){
        sp.ins(t[p].rt,arr[i]);
    }
    if(l == r){ t[p].mx = t[p].mn = arr[l];return; }
    int mid = (l+r) >> 1;
    build(p<<1,l,mid);
    build(p<<1|1,mid + 1,r);
    pushUp(p);
}

int getMax(int p,int l,int r){
    if(l <= t[p].l && t[p].r <=r) return t[p].mx;
    int mid = (t[p].l + t[p].r) >> 1,res = -inf;
    if(l <= mid) res = max(res,getMax(p << 1,l,r));
    if(mid < r) res = max(res,getMax(p << 1 | 1,l,r));
    return res;
}

int getMin(int p,int l,int r){
    if(l <= t[p].l && t[p].r <= r) return t[p].mx;
    int mid = (t[p].l + t[p].r) >> 1 ,res = inf;
    if(l <= mid) res = min(res,getMin(p << 1,l,r));
    if(mid < r) res = min(res,getMin(p << 1|1,l,r));
    return res;
}

int query_order(int p,int l,int r,int val){
    //查詢順序,就是查有多少個比他小
    if(l <= t[p].l && t[p].r <= r){
       sp.ins(t[p].rt,val);
       int res = s[s[t[p].rt].ch[0]].sz-1;
       sp.del(t[p].rt,val);
       return res;
    }
    int mid = (t[p].l + t[p].r) >> 1,res = 0;
    if(l <= mid) res += query_order(p << 1,l,r,val);
    if(mid < r) res += query_order(p << 1|1,l,r,val);
    return res;
}

void modify(int p,int pos,int val){
     sp.del(t[p].rt,arr[pos]);
     sp.ins(t[p].rt,val);
     if(t[p].l == t[p].r){
         t[p].mx = val;
         t[p].mn = val;
         arr[pos] = val;
         return;
     }
     int mid = (t[p].l + t[p].r) >> 1;
     if(pos <= mid) modify(p << 1,pos,val);
     if(pos > mid) modify(p << 1 | 1,pos,val);
     pushUp(p);
}

int query_Pre(int p,int l,int r,int val){
    if(l <= t[p].l && r >= t[p].r){
        sp.ins(t[p].rt,val);
        int res = s[sp.getPre(t[p].rt)].val;
        sp.del(t[p].rt,val);
        return res;
    }
    int res = -inf,mid = (t[p].l + t[p].r) >> 1;
    if(l <= mid)  res = max(res,query_Pre(p << 1,l,r,val));
    if(r > mid)  res = max(res,query_Pre(p << 1|1,l,r,val));
    return res;
}

int query_Nxt(int p,int l,int r,int val){
    if(l <= t[p].l && r >= t[p].r){
        sp.ins(t[p].rt,val);
        int res = s[sp.getNxt(t[p].rt)].val;
        sp.del(t[p].rt,val);
        return res;
    }
    int res = inf,mid = (t[p].l + t[p].r) >> 1;
    if(l <= mid)  res = min(res,query_Nxt(p << 1,l,r,val));
    if(r > mid)  res = min(res,query_Nxt(p << 1|1,l,r,val));
    return res;
}

int query_number(int L,int R,int val){
    int l = 1,r = getMax(1,L,R) ,mid,tmp;
    while(l < r){
        mid = (l + r + 1)>>1;
        tmp = query_order(1,L,R,mid);
        if(tmp < val){
            l = mid;
        }else{
            r = mid - 1;
        }
    }
    return l;
}

int main(){
    int n,q,op,l,r,pos;
    scanf("%d%d",&n,&q);
    for(int i=1;i<=n;++i) scanf("%d",&arr[i]);
    build(1,1,n);
    while(q--){
        scanf("%d",&op);
        if(op == 1){
            scanf("%d%d%d",&l,&r,&pos);
            printf("%d\n",query_order(1,l,r,pos)+1);
        }else if(op == 2){
            scanf("%d%d%d",&l,&r,&pos);
            printf("%d\n",query_number(l,r,pos));
        }else if(op == 3){
            scanf("%d%d",&l,&pos);
            modify(1,l,pos);
        }else if(op == 4){
            scanf("%d%d%d",&l,&r,&pos);
            printf("%d\n",query_Pre(1,l,r,pos));
        }else if(op == 5){
            scanf("%d%d%d",&l,&r,&pos);
            printf("%d\n",query_Nxt(1,l,r,pos));
        }
    }
    return 0;
}

代碼不加O2優化會超時,如果要優化的話,可以加個輸入輸出掛。

后記

博客兩周年快樂。

這是第一篇博客https://www.cnblogs.com/smallocean/p/8525932.html:2018.3.7

發現自己留下的東西都可以當作時間膠囊,等未來某天翻看的時候,仿佛能看到那個時候的自己。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM