主席樹詳解


主席樹是很簡(du)單(liu)的數據結構

題目給你一個序列,每次修改后算一個新的版本,詢問某個版本中某個值

我們先以Luogu P3919 【模板】可持久化數組(可持久化線段樹/平衡樹)作為模板講一下主席樹

主席樹(可持久化線段樹)

先學一下線段樹qaq

主席樹本名可持久化線段樹,也就是說,主席樹是基於線段樹發展而來的一種數據結構。其前綴"可持久化"意在給線段樹增加一些歷史點來維護歷史數據,使得我們能在較短時間內查詢歷史數據

不同於普通線段樹的是主席樹的左右子樹節點編號並不能夠用計算得到,所以我們需要記錄下來,但是對應的區間還是沒問題的。

我們注意到,對於修改操作,當前版本與它的前驅版本相比,只更改了一個節點的值,其他大多數節點的值沒有變化。

能不能重復利用,以達到節省空間的目的?

——分治?沒錯,如果只修改了左半邊,那么我們可以使用前驅版本的右半邊,反之同理。

於是,我們就可以用線段樹,進行修改操作時,只要當前節點的左(右)兒子沒有被修改,我們就可以使用前驅版本的那個節點。

那查找呢?每次保存版本i的根節點,利用線段樹的方法查找就好了。

代碼實現(代碼中有詳細注釋qaq):

#include <bits/stdc++.h>
#define N 1000005
using namespace std;
inline char nc(){
    static char buf[100000],*p1=buf,*p2=buf; 
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++; 
}
inline int read()
{
    register int x=0,f=1;register char ch=nc();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=nc();}
    while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+ch-'0',ch=nc();
    return x*f;
}
inline void write(register int x)
{
    if(!x)putchar('0');if(x<0)x=-x,putchar('-');
    static int sta[20];register int tot=0;
    while(x)sta[tot++]=x%10,x/=10;
    while(tot)putchar(sta[--tot]+48);
}
struct node{
	int rt[N],t[N<<5],ls[N<<5],rs[N<<5];
	int cnt;//尾節點,插入節點用
	inline int build(register int l,register int r)
	{
		int root=++cnt;
		if(l==r)
		{
			t[root]=read();//順帶讀入 
			return root;
		}
		int mid=l+r>>1;
		ls[root]=build(l,mid),rs[root]=build(mid+1,r);
		return root;
	}
	inline int update(register int pre,register int l,register int r,register int x,register int c)
	{
		int root=++cnt;
		if(l==r)
		{
			t[root]=c; //修改
			return root;
		}
		ls[root]=ls[pre],rs[root]=rs[pre];//先把子節點指向前驅結點以備復用
		int mid=l+r>>1;
		if(x<=mid)
			ls[root]=update(ls[pre],l,mid,x,c);
		else
			rs[root]=update(rs[pre],mid+1,r,x,c);
		return root;
	}
	inline void query(register int pre,register int l,register int r,register int x)
	{
		//普通的線段樹查詢
		if(l==r)
		{
			write(t[pre]),puts("");
			return;
		}
		int mid=l+r>>1;
		if(x<=mid)
			query(ls[pre],l,mid,x);
		else
			query(rs[pre],mid+1,r,x);
	}
}tr;
int main()
{
	tr.cnt=0;
	int n=read(),m=read();
	tr.build(1,n);
	tr.rt[0]=1;
	for(register int i=1;i<=m;++i)
	{
		int tic=read(),opt=read();
		if(opt==1)
		{
			int pos=read(),v=read();
			tr.rt[i]=tr.update(tr.rt[tic],1,n,pos,v);
		}
		else
		{
			int pos=read();
			tr.rt[i]=tr.rt[tic];
			tr.query(tr.rt[tic],1,n,pos);
		}
	}
	return 0;
 } 

還有一種問題是求靜態區間[l,r]中第k小的數

先給一個很暴力的做法:

先將區間進行排序(莫隊),再用平衡樹來求區間第k小

這樣的復雜度是 \(O(n \sqrt n \log n)\)

如果你有足夠的卡常技巧(A了挑戰),也許能卡過Luogu P3834 【模板】可持久化線段樹 1(主席樹)

50分莫隊+平衡樹做法

#pragma GCC optimize("O3")
#include <bits/stdc++.h>
#define N 500005
#define M 200005
using namespace std;
inline char nc(){
    static char buf[100000],*p1=buf,*p2=buf; 
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++; 
}
inline int read()
{
    register int x=0,f=1;register char ch=nc();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=nc();}
    while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+ch-'0',ch=nc();
    return x*f;
}
inline void write(register int x)
{
    if(!x)putchar('0');if(x<0)x=-x,putchar('-');
    static int sta[20];register int tot=0;
    while(x)sta[tot++]=x%10,x/=10;
    while(tot)putchar(sta[--tot]+48);
}
struct Splay{
    int v,fa,ch[2],sum,rec;
}tree[N];
int tot=0;
inline void update(register int x)
{
    tree[x].sum=tree[tree[x].ch[0]].sum+tree[tree[x].ch[1]].sum+tree[x].rec;
}
inline bool findd(register int x)
{
    return tree[tree[x].fa].ch[0]==x?0:1;
}
inline void connect(register int x,register int fa,register int son)
{
    tree[x].fa=fa;
    tree[fa].ch[son]=x;
} 
inline void rotate(register int x)
{
    int Y=tree[x].fa;
    int R=tree[Y].fa;
    int Yson=findd(x);
    int Rson=findd(Y);
    int B=tree[x].ch[Yson^1];
    connect(B,Y,Yson);
    connect(Y,x,Yson^1);
    connect(x,R,Rson);
    update(Y),update(x);
}
inline void splay(register int x,register int to)
{
    to=tree[to].fa;
    while(tree[x].fa!=to)
    {
        int y=tree[x].fa;
        if(tree[y].fa==to)
            rotate(x);
        else if(findd(x)==findd(y))
            rotate(y),rotate(x);
        else
            rotate(x),rotate(x);
    }	
}
inline int newpoint(register int v,register int fa)
{
    tree[++tot].fa=fa;
    tree[tot].v=v;
    tree[tot].sum=tree[tot].rec=1;
    return tot; 
}
inline void Insert(register int x)
{
    int now=tree[0].ch[1];
    if(tree[0].ch[1]==0)
    {
        newpoint(x,0);
        tree[0].ch[1]=tot;
    }
    else
    {
        while(19260817)
        {
            ++tree[now].sum;
            if(tree[now].v==x)
            {
                ++tree[now].rec;
                splay(now,tree[0].ch[1]);
                return;
            }
            int nxt=x<tree[now].v?0:1;
            if(!tree[now].ch[nxt])
            {
                int p=newpoint(x,now);
                tree[now].ch[nxt]=p;
                splay(p,tree[0].ch[1]);
                return;
            }
            now=tree[now].ch[nxt];
        }
    }
}
inline int find(register int v)
{
    int now=tree[0].ch[1];
    while(19260817)
    {
        if(tree[now].v==v)
        {
            splay(now,tree[0].ch[1]);
            return now;
        }
        int nxt=v<tree[now].v?0:1;
        if(!tree[now].ch[nxt])
            return 0;
        now=tree[now].ch[nxt];
    }
}
inline void delet(register int x)
{
    int pos=find(x);
    if(!pos)
        return;
    if(tree[pos].rec>1)
    {
        --tree[pos].rec;
        --tree[pos].sum;
    }
    else
    {
        if(!tree[pos].ch[0]&&!tree[pos].ch[1])
            tree[0].ch[1]=0;
        else if(!tree[pos].ch[0])
        {
            tree[0].ch[1]=tree[pos].ch[1];
            tree[tree[0].ch[1]].fa=0;
        }
        else
        {
            int left=tree[pos].ch[0];
            while(tree[left].ch[1])
                left=tree[left].ch[1];
            splay(left,tree[pos].ch[0]);
            connect(tree[pos].ch[1],left,1);
            connect(left,0,1);
            update(left);
        }
    }
}
inline int arank(register int x)
{
    int now=tree[0].ch[1];
    while(19260817)
    {
        int used=tree[now].sum-tree[tree[now].ch[1]].sum;
        if(x>tree[tree[now].ch[0]].sum&&x<=used)
        {
            splay(now,tree[0].ch[1]);
            return tree[now].v;
        }
        if(x<used)
            now=tree[now].ch[0];
        else
            x-=used,now=tree[now].ch[1];
    }
}
struct query{
    int l,r,id,bl,k;
}q[M];
int a[N],blocksize=0,ans[M];
inline bool cmp(register query a,register query b)
{
    return a.bl!=b.bl?a.l<b.l:((a.bl&1)?a.r<b.r:a.r>b.r);
}
int main()
{
    int n=read(),m=read();
    blocksize=sqrt(m);
    for(register int i=1;i<=n;++i)
        a[i]=read();
    for(register int i=1;i<=m;++i)
    {
        int l=read(),r=read(),k=read();
        q[i]=(query){l,r,i,l/blocksize,k};
    }
    sort(q+1,q+m+1,cmp);
    int l=1,r=0;
    for(register int i=1;i<=m;++i)
    {
        int ll=q[i].l,rr=q[i].r;
        while(ll<l)
            Insert(a[--l]);
        while(rr>r)
            Insert(a[++r]);
        while(ll>l)
            delet(a[l++]);
        while(rr<r)
            delet(a[r--]);
        ans[q[i].id]=arank(q[i].k);
    }
    for(register int i=1;i<=m;++i)
        write(ans[i]),puts("");
    return 0;
}

我們先考慮簡化的問題:我們要詢問整個區間內的第K小。這樣我們對值域建線段樹,每個節點記錄這個區間所包含的元素個數,建樹和查詢時的區間范圍用遞歸參數傳遞,然后用二叉查找樹的詢問方式即可:即如果左邊元素個數sum>=K,遞歸查找左子樹第K小,否則遞歸查找右子樹第K - sum小,直到返回葉子的值。

現在我們要回答對於區間[l, r]的第K小詢問。如果我們能夠得到一個插入原序列中[1, l - 1]元素的線段樹,和一顆插入了[1, r]元素的線段樹,由於線段樹是開在值域上,區間長度是一定的,所以結構也必然是完全相同的,我們可以直接對這兩顆線段樹進行相減,得到的是相當於插入了區間[l ,r]元素的線段樹。注意這里利用到的區間相減性質,實際上是用兩顆不同歷史版本的線段樹進行相減:一顆是插入到第l-1個元素的舊樹,一顆是插入到第r元素的新樹。

這樣相減之后得到的是相當於只插入了原序列中[l, r]元素的一顆記錄了區間數字個數的線段樹。直接對這顆線段樹按照BST的方式詢問,即可得到區間第k小。

這種做法是可行的,但是我們顯然不能每次插入一個元素,就從頭建立一顆全新的線段樹,否則內存開銷無法承受。事實上,每次插入一個新的元素時,我們不需要新建所有的節點,而是只新建增加的節點。也就是從根節點出發,先新建節點並復制原節點的值,然后進行修改即可。

這樣我們我們每到一個節點,只需要修改左兒子或者右兒子其一的信息,一直遞歸到葉子后結束,修改的節點數量就是樹高,也就是新建了不超過樹高個節點,內存開銷就可以承受了。

注意我們對root[0]也就是插入了零個元素的那顆樹,記錄的左右兒子指針都是0,這樣我們就可以用這一個節點表示一個任意結構的空樹而不需要顯式建樹。這是因為對於這個節點,不管你再怎么遞歸,都是指向這個節點本身,里面記錄的元素個數就是零。

#include <bits/stdc++.h>
#define N 200005
using namespace std;
inline char nc(){
    static char buf[100000],*p1=buf,*p2=buf; 
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++; 
}
inline int read()
{
    register int x=0,f=1;register char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    return x*f;
}
inline void write(register int x)
{
	if(!x)putchar('0');
    static int sta[20];register int tot=0;
    while(x)sta[tot++]=x%10,x/=10;
    while(tot)putchar(sta[--tot]+48);
}
int n,q,m,cnt=0;
int a[N],b[N],T[N];
int sum[N<<5],ls[N<<5],rs[N<<5];
inline int build(register int l,register int r)
{
    int root=++cnt;
    sum[root]=0;
    int mid=l+r>>1;
    if(l<r)
    	ls[root]=build(l,mid),rs[root]=build(mid+1,r);
    return root;
}
inline int update(register int pre,register int l,register int r,register int x)
{
    int root=++cnt;
    ls[root]=ls[pre],rs[root]=rs[pre],sum[root]=sum[pre]+1;
    int mid=l+r>>1;
    if(l<r)
    {
    	if(x<=mid)
    	    ls[root]=update(ls[pre],l,mid,x);
    	else
        	rs[root]=update(rs[pre],mid+1,r,x);
    }
    return root;
}
inline int query(register int u,register int v,register int l,register int r,register int k)
{
	if(l>=r)
		return l;
	int x=sum[ls[v]]-sum[ls[u]];
	int mid=l+r>>1;
	if(x>=k)
		return query(ls[u],ls[v],l,mid,k);
	else
		return query(rs[u],rs[v],mid+1,r,k-x);
}
int main()
{
	n=read(),q=read();
	for(register int i=1;i<=n;++i)
		b[i]=a[i]=read();
	sort(b+1,b+n+1);
	m=unique(b+1,b+n+1)-b-1;
	T[0]=build(1,m);
	for(register int i=1;i<=n;++i)
	{
		int t=lower_bound(b+1,b+m+1,a[i])-b;
		T[i]=update(T[i-1],1,m,t);
	}
	while(q--)
	{
		int l=read(),r=read(),k=read();
		int t=query(T[l-1],T[r],1,m,k);
		write(b[t]),puts("");
	}
}

但是要注意,主席樹在不做額外處理時只能查詢靜態的區間k大(小)值。

接下來,我們就考慮動態區間k小值。如果我們要對區間進行修改的話,一個簡單的主席樹已經無法實現了。

如果對原來的節點直接修改的話,會造成不可名狀的運行錯誤(有興趣的同學可以結合上面插入代碼想一想為什么),

空間和時間也無法接受(我們需要把后面所有樹都更改一下),但我們在做樹套樹的時候,可以做類似的操作,那么主席樹是不是應該也套些什么呢?

主席樹上的點,儲存的都是在一段權值區間內的數據個數,我們必須要維護數據個數才可以通過相減得到一段區間的權值線段樹。

而現在有了修改,對於這個修改的維護,朴素的做法有2種:O(1)查詢,O(n)維護(掃一遍),和O(n)查詢(現場算)和O(1)維護。

這兩種做法都不是很憂,所以我們考慮利用快捷維護前綴和的樹狀數組解決這個問題,即所謂“樹狀數組套主席樹”

#include <bits/stdc++.h>
#define N 100005
#define M 40000005
using namespace std;
inline int read()
{
    register int x=0,f=1;register char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    return x*f;
}
inline void write(register int x)
{
    if(!x)putchar('0');if(x<0)x=-x,putchar('-');
    static int sta[25];int tot=0;
    while(x)sta[tot++]=x%10,x/=10;
    while(tot)putchar(sta[--tot]+48);
}
int a[N];  
int n,m;  
int root[N],ls[M],rs[M],c[M];  
int tot=0;  
int xx[40],yy[40];  
int v,d; 
inline int lowbit(register int x)
{
    return x&(-x);
}
inline void update(register int &now,register int l,register int r)
{
    if(now==0)
        now=++tot;
    c[now]+=d;
    if(l==r)
        return;
    int mid=l+r>>1;
    if(v<=mid)
        update(ls[now],l,mid);
    else
        update(rs[now],mid+1,r);
}
inline void change()
{
    int x=read(),b=read();
    d=-1,v=a[x];
    for(register int i=x;i<=n;i+=lowbit(i))
        update(root[i],0,1e9);
    d=1,v=b;
    for(register int i=x;i<=n;i+=lowbit(i))
        update(root[i],0,1e9);
    a[x]=b;
}
inline int query()
{
    int x=read(),y=read(),k=read();
    --x;
    x^=y^=x^=y;
    int t1=0,t2=0;
    for(register int i=x;i>=1;i-=lowbit(i))
        xx[++t1]=root[i];
    for(register int i=y;i>=1;i-=lowbit(i))
        yy[++t2]=root[i];
    int l=0,r=1e9;
    while(l<r)
    {
        int temp=0;
        for(register int i=1;i<=t1;++i)
            temp+=c[ls[xx[i]]];
        for(register int i=1;i<=t2;++i)
            temp-=c[ls[yy[i]]];
        if(k<=temp)
        {
            for(register int i=1;i<=t1;++i)
                xx[i]=ls[xx[i]];
            for(register int i=1;i<=t2;++i)
                yy[i]=ls[yy[i]];
            r=l+r>>1;   
        }
        else
        {
            for(register int i=1;i<=t1;++i)
                xx[i]=rs[xx[i]];
            for(register int i=1;i<=t2;++i)
                yy[i]=rs[yy[i]];
            k-=temp;
            l=(l+r>>1)+1;
        }
    }
    return l;
}
int main()
{
    n=read(),m=read();
    for(register int i=1;i<=n;++i)
    {
        v=read();
        a[i]=v,d=1;
        for(register int j=i;j<=n;j+=lowbit(j))
            update(root[j],0,1e9);
    }
    while(m--)
    {
        char ch=getchar();
        while(ch!='C'&&ch!='Q')
            ch=getchar();
        if(ch=='Q')
            write(query()),puts("");
        else
            change();
    }
    return 0;
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM