主席树真是神仙操作啊……搞了好久才弄懂一点点QAQ
参考文章:https://www.cnblogs.com/zyf0163/p/4749042.html
https://blog.csdn.net/creatorx/article/details/75446472
https://blog.csdn.net/jerans/article/details/75807666
http://www.cnblogs.com/zcysky/p/6832876.html
ps:本文章中的题目我都写过题解了,可以自己去找
1.前言
据说主席树这个名字的由来呢,是因为创始人的名字缩写hjt与某位相同,然后他因为不会划分树于是自创了这一个数据结构。好强啊orz
主席树能实现什么操作呢?最经典的就是查询区间第k小了,其他的还有诸如树上路径第k小啦,带修改第k小啦之类的。以静态区间第k小为例
2.定义
先贴一下某神犇对主席树的理解:所谓主席树呢,就是对原来的数列[1..n]的每一个前缀[1..i](1≤i≤n)建立一棵线段树,线段树的每一个节点存某个前缀[1..i]中属于区间[L..R]的数一共有多少个(比如根节点是[1..n],一共i个数,sum[root] = i;根节点的左儿子是[1..(L+R)/2],若不大于(L+R)/2的数有x个,那么sum[root.left] = x)。若要查找[i..j]中第k大数时,设某结点x,那么x.sum[j] - x.sum[i - 1]就是[i..j]中在结点x内的数字总数。而对每一个前缀都建一棵树,会MLE,观察到每个[1..i]和[1..i-1]只有一条路是不一样的,那么其他的结点只要用回前一棵树的结点即可,时空复杂度为O(nlogn)。
然而没有什么用,因为感觉根本没看懂
然后来说说我自己的理解吧。如何求出一个区间内第k小呢?直接sort当然可以,但是复杂度爆表。于是我们可以换一个思路,能否将$[l,r]$之间出现过的数都建成线段树呢?设节点为$p$,区间为$[l,r]$,左儿子是$[l,mid]$,右儿子是$[mid+1,r]$
要查找第k大的话,先看左儿子里有多少个数(表示小于等于$mid$的数的个数),如果大于$k$,进左子树找,否则令$k-=左儿子数的个数$,进右子树找
先来考虑一个序列:3,2,1,4
建完树之后是这样的
然后要查第2大,一下子就能发现是2了
(上面画的可能不是很严谨,大家将就下)
但我们不可能对每一个区间都建一棵树,那样的话空间复杂度绝对爆炸
然后可以转化一下思路:前缀和
区间$[l,r]$中小于等于$mid$的数的个数,可以转换为$[1,r]$中小于等于$mid$的数的个数减去$[1,l-1]$中小于等于$mid$的数的个数
于是我们只要对每一个前缀建一棵树即可
然后空间复杂度还是爆炸
然而我们又发现,区间$[1,l-1]$的树和区间$[1,l]$的树最多只会有$log n$个节点不同(因为每次新插入一个节点最多只会更新$log n$个节点),有许多空间是可以重复利用的
只要能将这些空间重复利用起来,就可以解决空间的问题了
还是上面那个序列:3,2,1,4
一开始先建一棵空树,然后一个个把每一个节点加进去
如果要看图的话可以点这里
这个时候有人就要问了,万一序列的数字特别大呢?
当然是离散化
将这些所有值离散一下就行了,可以保证所有数在$1~n$之间
然而感觉讲太多也没啥用……上代码好了,有详细的注释
以区间第k小为例 洛谷p3834

1 //minamoto 2 #include<bits/stdc++.h> 3 #define N 200005 4 using namespace std; 5 inline int read(){ 6 #define num ch-'0' 7 char ch;bool flag=0;int res; 8 while(!isdigit(ch=getchar())) 9 (ch=='-')&&(flag=true); 10 for(res=num;isdigit(ch=getchar());res=res*10+num); 11 (flag)&&(res=-res); 12 #undef num 13 return res; 14 } 15 int sum[N<<5],L[N<<5],R[N<<5]; 16 int a[N],b[N],t[N]; 17 int n,q,m,cnt=0; 18 int build(int l,int r){ 19 int rt=++cnt; 20 //建树 21 sum[rt]=0; 22 if(l<r){ 23 int mid=(l+r)>>1; 24 L[rt]=build(l,mid); 25 R[rt]=build(mid+1,r); 26 } 27 return rt; 28 } 29 int update(int last,int l,int r,int x){ 30 int rt=++cnt; 31 L[rt]=L[last],R[rt]=R[last],sum[rt]=sum[last]+1; 32 //先继承上一次的信息 33 //L是左节点,R是右节点,sum是节点内数的个数 34 if(l<r){ 35 int mid=(l+r)>>1; 36 if(x<=mid) L[rt]=update(L[last],l,mid,x); 37 else R[rt]=update(R[last],mid+1,r,x); 38 //如果有需要更新的信息,更新 39 //可以发现每一次更新的节点最多只有log n个 40 } 41 return rt; 42 } 43 int query(int u,int v,int l,int r,int k){ 44 if(l>=r) return l; 45 int x=sum[L[v]]-sum[L[u]]; 46 //查询操作 47 int mid=(l+r)>>1; 48 if(x>=k) return query(L[u],L[v],l,mid,k); 49 else return query(R[u],R[v],mid+1,r,k-x); 50 //如果左节点个数大于等于k,进左子树找第k小 51 //否则进右子树 52 } 53 int main(){ 54 //freopen("testdata.in","r",stdin); 55 n=read(),q=read(); 56 for(int i=1;i<=n;++i) 57 b[i]=a[i]=read(); 58 sort(b+1,b+1+n); 59 m=unique(b+1,b+1+n)-b-1; 60 t[0]=build(1,m); 61 //先建一棵空树 62 for(int i=1;i<=n;++i){ 63 int k=lower_bound(b+1,b+1+m,a[i])-b; 64 //离散 65 t[i]=update(t[i-1],1,m,k); 66 //然后每次在上一次的基础上建树 67 } 68 while(q--){ 69 int x,y,z; 70 x=read(),y=read(),z=read(); 71 int k=query(t[x-1],t[y],1,m,z); 72 printf("%d\n",b[k]); 73 } 74 return 0; 75 }
如果熟练了之后,可以发现其实第一步的建树过程是可以省略的,直接每一步加节点就行了

1 //minamoto 2 #include<bits/stdc++.h> 3 #define N 200005 4 using namespace std; 5 inline int read(){ 6 #define num ch-'0' 7 char ch;bool flag=0;int res; 8 while(!isdigit(ch=getchar())) 9 (ch=='-')&&(flag=true); 10 for(res=num;isdigit(ch=getchar());res=res*10+num); 11 (flag)&&(res=-res); 12 #undef num 13 return res; 14 } 15 int sum[N<<5],L[N<<5],R[N<<5]; 16 int a[N],b[N],t[N]; 17 int n,q,m,cnt=0; 18 void update(int last,int &now,int l,int r,int x){ 19 //注意这里开的是引用 20 if(!now) now=++cnt; 21 sum[now]=sum[last]+1; 22 if(l==r) return; 23 int mid=(l+r)>>1; 24 if(x<=mid) R[now]=R[last],update(L[last],L[now],l,mid,x); 25 else L[now]=L[last],update(R[last],R[now],mid+1,r,x); 26 } 27 int query(int u,int v,int l,int r,int k){ 28 if(l>=r) return l; 29 int x=sum[L[v]]-sum[L[u]]; 30 int mid=(l+r)>>1; 31 if(x>=k) return query(L[u],L[v],l,mid,k); 32 else return query(R[u],R[v],mid+1,r,k-x); 33 } 34 int main(){ 35 //freopen("testdata.in","r",stdin); 36 n=read(),q=read(); 37 for(int i=1;i<=n;++i) 38 b[i]=a[i]=read(); 39 sort(b+1,b+1+n); 40 m=unique(b+1,b+1+n)-b-1; 41 for(int i=1;i<=n;++i){ 42 int k=lower_bound(b+1,b+1+m,a[i])-b; 43 update(t[i-1],t[i],1,m,k); 44 //省略建树过程,直接加入节点 45 } 46 while(q--){ 47 int x,y,z; 48 x=read(),y=read(),z=read(); 49 int k=query(t[x-1],t[y],1,m,z); 50 printf("%d\n",b[k]); 51 } 52 return 0; 53 }
还有一道板子题洛谷SP3946 poj2104 K-th Number

1 //minamoto 2 #include<bits/stdc++.h> 3 #define N 100005 4 using namespace std; 5 inline int read(){ 6 #define num ch-'0' 7 char ch;bool flag=0;int res; 8 while(!isdigit(ch=getchar())) 9 (ch=='-')&&(flag=true); 10 for(res=num;isdigit(ch=getchar());res=res*10+num); 11 (flag)&&(res=-res); 12 #undef num 13 return res; 14 } 15 int sum[N<<5],L[N<<5],R[N<<5]; 16 int a[N],b[N],rt[N]; 17 int n,q,m,cnt=0; 18 void update(int last,int &now,int l,int r,int x){ 19 sum[now=++cnt]=sum[last]+1; 20 if(l==r) return; 21 int mid=(l+r)>>1; 22 if(x<=mid) R[now]=R[last],update(L[last],L[now],l,mid,x); 23 else L[now]=L[last],update(R[last],R[now],mid+1,r,x); 24 } 25 int query(int u,int v,int l,int r,int k){ 26 if(l>=r) return l; 27 int x=sum[L[v]]-sum[L[u]]; 28 int mid=(l+r)>>1; 29 if(x>=k) return query(L[u],L[v],l,mid,k); 30 else return query(R[u],R[v],mid+1,r,k-x); 31 } 32 int main(){ 33 //freopen("testdata.in","r",stdin); 34 n=read(),q=read(); 35 for(int i=1;i<=n;++i) 36 b[i]=a[i]=read(); 37 sort(b+1,b+1+n); 38 m=unique(b+1,b+1+n)-b-1; 39 for(int i=1;i<=n;++i){ 40 int k=lower_bound(b+1,b+1+m,a[i])-b; 41 update(rt[i-1],rt[i],1,m,k); 42 } 43 while(q--){ 44 int x,y,z; 45 x=read(),y=read(),z=read(); 46 int k=query(rt[x-1],rt[y],1,m,z); 47 printf("%d\n",b[k]); 48 } 49 return 0; 50 }
还有一道题,也是主席树的一般应用 洛谷P3567 [POI2014]KUR-Couriers
出现次数可以转化为左右节点的大小,如果符合条件就递归

1 //minamoto 2 #include<bits/stdc++.h> 3 #define N 500005 4 using namespace std; 5 inline int read(){ 6 #define num ch-'0' 7 char ch;bool flag=0;int res; 8 while(!isdigit(ch=getchar())) 9 (ch=='-')&&(flag=true); 10 for(res=num;isdigit(ch=getchar());res=res*10+num); 11 (flag)&&(res=-res); 12 #undef num 13 return res; 14 } 15 int sum[N*20],L[N*20],R[N*20],t[N]; 16 int n,q,cnt=0; 17 void update(int last,int &now,int l,int r,int x){ 18 if(!now) now=++cnt; 19 sum[now]=sum[last]+1; 20 if(l==r) return; 21 int mid=(l+r)>>1; 22 if(x<=mid) R[now]=R[last],update(L[last],L[now],l,mid,x); 23 else L[now]=L[last],update(R[last],R[now],mid+1,r,x); 24 } 25 int query(int u,int v,int l,int r,int k){ 26 if(l==r) return l; 27 int x=sum[L[v]]-sum[L[u]],y=sum[R[v]]-sum[R[u]]; 28 int mid=(l+r)>>1; 29 if(x*2>k) return query(L[u],L[v],l,mid,k); 30 if(y*2>k) return query(R[u],R[v],mid+1,r,k); 31 return 0; 32 } 33 int main(){ 34 //freopen("testdata.in","r",stdin); 35 n=read(),q=read(); 36 for(int i=1;i<=n;++i){ 37 int x=read(); 38 update(t[i-1],t[i],1,n,x); 39 } 40 while(q--){ 41 int x,y; 42 x=read(),y=read(); 43 int k=query(t[x-1],t[y],1,n,y-x+1); 44 printf("%d\n",k); 45 } 46 return 0; 47 }
然后区间静态第k大就解决了~\(≧▽≦)/~啦啦啦
树上路径
有些题目会给你一棵树,问你树上两点间路径上的第k大
怎么解决呢?
可以发现,这个东西是可以进行差分的
比如说,$u$到$v$路径上的权值和,可以变成$sum[u]+sum[v]-sum[lca]-sum[lca_fa]$
然后套到主席树上,就是小于某个数的个数,同样也可以差分出来表示
但问题是主席树怎么建呢?
我们发现,因为要求lca,我们可以在树剖dfs的时候顺便加点
具体来说,就是用$fa[i]$的信息更新$i$点的信息
以bzoj2588 洛谷p2633. count on a tree为例

1 //minamoto 2 #include<bits/stdc++.h> 3 #define N 100005 4 #define M 2000005 5 using namespace std; 6 inline int read(){ 7 #define num ch-'0' 8 char ch;bool flag=0;int res; 9 while(!isdigit(ch=getchar())) 10 (ch=='-')&&(flag=true); 11 for(res=num;isdigit(ch=getchar());res=res*10+num); 12 (flag)&&(res=-res); 13 #undef num 14 return res; 15 } 16 int sum[M],L[M],R[M]; 17 int a[N],b[N],rt[N]; 18 int fa[N],sz[N],d[N],ver[N<<1],Next[N<<1],head[N],son[N],top[N]; 19 int n,q,m,cnt=0,tot=0,ans=0; 20 void update(int last,int &now,int l,int r,int x){ 21 sum[now=++cnt]=sum[last]+1; 22 if(l==r) return; 23 int mid=(l+r)>>1; 24 if(x<=mid) R[now]=R[last],update(L[last],L[now],l,mid,x); 25 else L[now]=L[last],update(R[last],R[now],mid+1,r,x); 26 } 27 inline void add(int u,int v){ 28 ver[++tot]=v,Next[tot]=head[u],head[u]=tot; 29 ver[++tot]=u,Next[tot]=head[v],head[v]=tot; 30 } 31 void dfs(int u){ 32 sz[u]=1,d[u]=d[fa[u]]+1; 33 update(rt[fa[u]],rt[u],1,m,a[u]); 34 for(int i=head[u];i;i=Next[i]){ 35 int v=ver[i]; 36 if(v==fa[u]) continue; 37 fa[v]=u,dfs(v); 38 sz[u]+=sz[v]; 39 if(!son[u]||sz[v]>sz[son[u]]) son[u]=v; 40 } 41 } 42 void dfs(int u,int tp){ 43 top[u]=tp; 44 if(!son[u]) return; 45 dfs(son[u],tp); 46 for(int i=head[u];i;i=Next[i]){ 47 int v=ver[i]; 48 if(v==son[u]||v==fa[u]) continue; 49 dfs(v,v); 50 } 51 } 52 int LCA(int x,int y){ 53 while(top[x]!=top[y]) 54 d[top[x]]>=d[top[y]]?x=fa[top[x]]:y=fa[top[y]]; 55 return d[x]>=d[y]?y:x; 56 } 57 int query(int ql,int qr,int lca,int lca_fa,int l,int r,int k){ 58 if(l>=r) return l; 59 int x=sum[L[ql]]+sum[L[qr]]-sum[L[lca]]-sum[L[lca_fa]]; 60 int mid=(l+r)>>1; 61 if(x>=k) return query(L[ql],L[qr],L[lca],L[lca_fa],l,mid,k); 62 else return query(R[ql],R[qr],R[lca],R[lca_fa],mid+1,r,k-x); 63 } 64 int main(){ 65 //freopen("testdata.in","r",stdin); 66 n=read(),q=read(); 67 for(int i=1;i<=n;++i) 68 b[i]=a[i]=read(); 69 sort(b+1,b+1+n); 70 m=unique(b+1,b+1+n)-b-1; 71 for(int i=1;i<=n;++i) 72 a[i]=lower_bound(b+1,b+1+m,a[i])-b; 73 for(int i=1;i<n;++i){ 74 int u=read(),v=read(); 75 add(u,v); 76 } 77 dfs(1),dfs(1,1); 78 while(q--){ 79 int x,y,z,lca; 80 x=read(),y=read(),z=read(); 81 x^=ans,lca=LCA(x,y); 82 ans=b[query(rt[x],rt[y],rt[lca],rt[fa[lca]],1,m,z)]; 83 printf("%d\n",ans); 84 } 85 return 0; 86 }
还有一题[bzoj3123][洛谷P3302] [SDOI2013]森林
路经查询就是主席树维护,而连接两棵树就是用启发式合并

1 //minamoto 2 #include<bits/stdc++.h> 3 using namespace std; 4 inline int read(){ 5 #define num ch-'0' 6 char ch;bool flag=0;int res; 7 while(!isdigit(ch=getchar())) 8 (ch=='-')&&(flag=true); 9 for(res=num;isdigit(ch=getchar());res=res*10+num); 10 (flag)&&(res=-res); 11 #undef num 12 return res; 13 } 14 const int N=80005,M=N*200; 15 int ver[N<<2],Next[N<<2],head[N]; 16 int a[N],fa[N],sz[N],b[N]; 17 int n,m,tot,q,size,ans; 18 void add(int u,int v){ 19 ver[++tot]=v,Next[tot]=head[u],head[u]=tot; 20 ver[++tot]=u,Next[tot]=head[v],head[v]=tot; 21 } 22 int L[M],R[M],sum[M],rt[N],cnt; 23 void update(int last,int &now,int l,int r,int x){ 24 sum[now=++cnt]=sum[last]+1; 25 if(l==r) return; 26 int mid=(l+r)>>1; 27 if(x<=mid) R[now]=R[last],update(L[last],L[now],l,mid,x); 28 else L[now]=L[last],update(R[last],R[now],mid+1,r,x); 29 } 30 int query(int u,int v,int lca,int lca_fa,int l,int r,int k){ 31 if(l>=r) return l; 32 int x=sum[L[v]]+sum[L[u]]-sum[L[lca]]-sum[L[lca_fa]]; 33 int mid=(l+r)>>1; 34 if(x>=k) return query(L[u],L[v],L[lca],L[lca_fa],l,mid,k); 35 else return query(R[u],R[v],R[lca],R[lca_fa],mid+1,r,k-x); 36 } 37 inline int hash(int x){ 38 return lower_bound(b+1,b+1+size,x)-b; 39 } 40 int ff(int x){ 41 return fa[x]==x?x:fa[x]=ff(fa[x]); 42 } 43 int st[N][17],d[N],vis[N]; 44 void dfs(int u,int father,int root){ 45 st[u][0]=father; 46 for(int i=1;i<=16;++i) 47 st[u][i]=st[st[u][i-1]][i-1]; 48 ++sz[root]; 49 d[u]=d[father]+1; 50 fa[u]=root; 51 vis[u]=1; 52 update(rt[father],rt[u],1,size,hash(a[u])); 53 for(int i=head[u];i;i=Next[i]){ 54 int v=ver[i]; 55 if(v==father) continue; 56 dfs(v,u,root); 57 } 58 } 59 int LCA(int x,int y){ 60 if(x==y) return x; 61 if(d[x]<d[y]) swap(x,y); 62 for(int i=16;i>=0;--i){ 63 if(d[st[x][i]]>=d[y]) x=st[x][i]; 64 } 65 if(x==y) return x; 66 for(int i=16;i>=0;--i){ 67 if(st[x][i]!=st[y][i]) 68 x=st[x][i],y=st[y][i]; 69 } 70 return st[x][0]; 71 } 72 int main(){ 73 //freopen("testdata.in","r",stdin); 74 int t=read(); 75 n=read(),m=read(),q=read(); 76 for(int i=1;i<=n;++i) 77 a[i]=b[i]=read(),fa[i]=i; 78 sort(b+1,b+1+n); 79 size=unique(b+1,b+1+n)-b-1; 80 for(int i=1;i<=m;++i){ 81 int u=read(),v=read(); 82 add(u,v); 83 } 84 for(int i=1;i<=n;++i) 85 if(!vis[i]) dfs(i,0,i); 86 while(q--){ 87 char ch;int x,y; 88 while(!isupper(ch=getchar())); 89 x=read()^ans,y=read()^ans; 90 if(ch=='Q'){ 91 int k=read()^ans; 92 int lca=LCA(x,y); 93 ans=b[query(rt[x],rt[y],rt[lca],rt[st[lca][0]],1,size,k)]; 94 printf("%d\n",ans); 95 } 96 else{ 97 add(x,y); 98 int u=ff(x),v=ff(y); 99 if(sz[u]<sz[v]) swap(x,y),swap(u,v); 100 dfs(y,x,u); 101 } 102 } 103 return 0; 104 }
洛谷P3066 [USACO12DEC]逃跑的BarnRunning Away From…
要对每一个子树进行操作,怎么做呢?
我们可以直接dfs这棵树,并记录下进入一个点的编号$l[i]$和从这个点出去时的编号$r[i]$
那么这个点的子树的区间一定是$[l[i],r[i]]$
然后直接在树上查询就行了

1 //minamoto 2 #include<bits/stdc++.h> 3 #define N 200005 4 #define M 4000005 5 #define ll long long 6 #define inf 0x3f3f3f3f 7 using namespace std; 8 inline ll read(){ 9 #define num ch-'0' 10 char ch;bool flag=0;ll res; 11 while(!isdigit(ch=getchar())) 12 (ch=='-')&&(flag=true); 13 for(res=num;isdigit(ch=getchar());res=res*10+num); 14 (flag)&&(res=-res); 15 #undef num 16 return res; 17 } 18 int sum[M],L[M],R[M],rt[N]; 19 int ver[N<<1],Next[N<<1],head[N];ll edge[N<<1]; 20 int ls[N],rs[N];ll a[N],b[N]; 21 int n,m,cnt,tot;ll p; 22 void update(int last,int &now,int l,int r,int x){ 23 sum[now=++cnt]=sum[last]+1; 24 if(l==r) return; 25 int mid=(l+r)>>1; 26 if(x<=mid) R[now]=R[last],update(L[last],L[now],l,mid,x); 27 else L[now]=L[last],update(R[last],R[now],mid+1,r,x); 28 } 29 int query(int u,int v,int l,int r,int k){ 30 if(r<k) return sum[v]-sum[u]; 31 if(l>=k) return 0; 32 int mid=(l+r)>>1; 33 if(k<=mid) return query(L[u],L[v],l,mid,k); 34 else return query(R[u],R[v],mid+1,r,k)+sum[L[v]]-sum[L[u]]; 35 } 36 inline void add(int u,int v,ll e){ 37 ver[++tot]=v,Next[tot]=head[u],head[u]=tot,edge[tot]=e; 38 } 39 void dfs(int u,int fa,ll d){ 40 b[ls[u]=++m]=d,a[m]=d; 41 for(int i=head[u];i;i=Next[i]) 42 if(ver[i]!=fa) dfs(ver[i],u,d+edge[i]); 43 rs[u]=m; 44 } 45 int main(){ 46 n=read(),p=read(); 47 for(int u=2;u<=n;++u){ 48 int v=read();ll e=read(); 49 add(v,u,e); 50 } 51 dfs(1,0,0); 52 sort(b+1,b+1+m); 53 m=unique(b+1,b+1+m)-b-1; 54 for(int i=1;i<=n;++i){ 55 int k=lower_bound(b+1,b+1+m,a[i])-b; 56 update(rt[i-1],rt[i],1,m,k); 57 } 58 b[m+1]=inf; 59 for(int i=1;i<=n;++i){ 60 int k=upper_bound(b+1,b+2+m,a[ls[i]]+p)-b; 61 k=query(rt[ls[i]-1],rt[rs[i]],1,m,k); 62 printf("%d\n",k); 63 } 64 return 0; 65 }
bzoj 1803: Spoj1487 Query on a tree III(主席树)。基础的树上查询

1 //minamoto 2 #include<iostream> 3 #include<cstdio> 4 #include<algorithm> 5 using namespace std; 6 #define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++) 7 char buf[1<<21],*p1=buf,*p2=buf; 8 inline int read(){ 9 #define num ch-'0' 10 char ch;bool flag=0;int res; 11 while(!isdigit(ch=getc())) 12 (ch=='-')&&(flag=true); 13 for(res=num;isdigit(ch=getc());res=res*10+num); 14 (flag)&&(res=-res); 15 #undef num 16 return res; 17 } 18 char obuf[1<<24],*o=obuf; 19 inline void print(int x){ 20 if(x>9) print(x/10); 21 *o++=x%10+48; 22 } 23 const int N=100005,M=N*30; 24 int sum[M],L[M],R[M],rt[N]; 25 int ver[N<<1],Next[N<<1],head[N]; 26 int ls[N],rs[N],a[N],b[N],id[N],pos[N]; 27 int n,m,cnt,tot,q; 28 void update(int last,int &now,int l,int r,int x){ 29 sum[now=++cnt]=sum[last]+1; 30 if(l==r) return; 31 int mid=(l+r)>>1; 32 if(x<=mid) R[now]=R[last],update(L[last],L[now],l,mid,x); 33 else L[now]=L[last],update(R[last],R[now],mid+1,r,x); 34 } 35 int query(int u,int v,int l,int r,int k){ 36 if(l>=r) return l; 37 int x=sum[L[v]]-sum[L[u]]; 38 int mid=(l+r)>>1; 39 if(x>=k) return query(L[u],L[v],l,mid,k); 40 else return query(R[u],R[v],mid+1,r,k-x); 41 } 42 inline void add(int u,int v){ 43 ver[++tot]=v,Next[tot]=head[u],head[u]=tot; 44 ver[++tot]=u,Next[tot]=head[v],head[v]=tot; 45 } 46 void dfs(int u,int fa){ 47 a[ls[u]=++m]=b[u],id[m]=u; 48 for(int i=head[u];i;i=Next[i]) 49 if(ver[i]!=fa) dfs(ver[i],u); 50 rs[u]=m; 51 } 52 int main(){ 53 //freopen("testdata.in","r",stdin); 54 n=read(); 55 for(int i=1;i<=n;++i) b[i]=read(); 56 for(int i=1;i<n;++i){ 57 int u,v; 58 u=read(),v=read(); 59 add(u,v); 60 } 61 dfs(1,0); 62 sort(b+1,b+1+m); 63 for(int i=1;i<=n;++i){ 64 int k=lower_bound(b+1,b+1+m,a[i])-b; 65 update(rt[i-1],rt[i],1,m,k); 66 pos[k]=id[i]; 67 } 68 q=read(); 69 while(q--){ 70 int u=read(),k=read(); 71 int ans=pos[query(rt[ls[u]-1],rt[rs[u]],1,m,k)]; 72 print(ans),*o++='\n'; 73 } 74 fwrite(obuf,o-obuf,1,stdout); 75 return 0; 76 }
带修改主席树
我们可以发现,主席树每一棵线段树维护的都是一个前缀和
如果有修改操作,每一次都要对后面的所有的前缀和都进行修改,那样的话时间复杂度就太爆炸了
我们可以考虑一下树状数组
树状数组维护的也是前缀和,但它的每一次修改是$O(log n)$的
他的节点存的并不是前缀和,但我们仍可以用树状数组来求出前缀和
于是我们可以用树状数组的思想来维护,主席树
用树状数组存一下每个节点的位置,每一次修改都按树状数组的方法去修改,也就是说并不需要修改那么多节点
查询的时候,也按树状数组的方法查询就好了
建议对这段话仔细理解,我当初也是懵逼了好久,最后看了zcysky大佬的那篇blog才蓦然醒悟的
拿bzoj1901洛谷P2617 Dynamic Rankings为例
是一个带修改主席树的板子
思路就按我上面所说的

1 //minamoto 2 #include<bits/stdc++.h> 3 #define N 10005 4 using namespace std; 5 inline int read(){ 6 #define num ch-'0' 7 char ch;bool flag=0;int res; 8 while(!isdigit(ch=getchar())) 9 (ch=='-')&&(flag=true); 10 for(res=num;isdigit(ch=getchar());res=res*10+num); 11 (flag)&&(res=-res); 12 #undef num 13 return res; 14 } 15 inline int lowbit(int x){return x&(-x);} 16 int sum[N*600],L[N*600],R[N*600]; 17 int xx[N],yy[N],rt[N],a[N],b[N<<1],ca[N],cb[N],cc[N]; 18 int n,q,m,cnt=0,totx,toty; 19 void update(int last,int &now,int l,int r,int x,int v){ 20 sum[now=++cnt]=sum[last]+v; 21 L[now]=L[last],R[now]=R[last]; 22 if(l==r) return; 23 int mid=(l+r)>>1; 24 if(x<=mid) update(L[last],L[now],l,mid,x,v); 25 else update(R[last],R[now],mid+1,r,x,v); 26 } 27 int query(int l,int r,int q){ 28 if(l==r) return l; 29 int x=0,mid=(l+r)>>1; 30 for(int i=1;i<=totx;++i) x-=sum[L[xx[i]]]; 31 for(int i=1;i<=toty;++i) x+=sum[L[yy[i]]]; 32 if(q<=x){ 33 for(int i=1;i<=totx;++i) xx[i]=L[xx[i]]; 34 for(int i=1;i<=toty;++i) yy[i]=L[yy[i]]; 35 return query(l,mid,q); 36 } 37 else{ 38 for(int i=1;i<=totx;++i) xx[i]=R[xx[i]]; 39 for(int i=1;i<=toty;++i) yy[i]=R[yy[i]]; 40 return query(mid+1,r,q-x); 41 } 42 } 43 void add(int x,int y){ 44 int k=lower_bound(b+1,b+1+m,a[x])-b; 45 for(int i=x;i<=n;i+=lowbit(i)) update(rt[i],rt[i],1,m,k,y); 46 } 47 int main(){ 48 //freopen("testdata.in","r",stdin); 49 n=read(),q=read(); 50 for(int i=1;i<=n;++i) 51 b[++m]=a[i]=read(); 52 for(int i=1;i<=q;++i){ 53 char ch; 54 while(!isupper(ch=getchar())); 55 ca[i]=read(),cb[i]=read(); 56 if(ch=='Q') cc[i]=read();else b[++m]=cb[i]; 57 } 58 sort(b+1,b+1+m); 59 m=unique(b+1,b+1+m)-b-1; 60 for(int i=1;i<=n;++i) add(i,1); 61 for(int i=1;i<=q;++i){ 62 if(cc[i]){ 63 totx=toty=0; 64 for(int j=ca[i]-1;j;j-=lowbit(j)) xx[++totx]=rt[j]; 65 for(int j=cb[i];j;j-=lowbit(j)) yy[++toty]=rt[j]; 66 printf("%d\n",b[query(1,m,cc[i])]); 67 } 68 else{add(ca[i],-1),a[ca[i]]=cb[i],add(ca[i],1);} 69 } 70 return 0; 71 }
还有一道[BZOJ3295] [Cqoi2011]洛谷p3157动态逆序对

1 //minamoto 2 #include<bits/stdc++.h> 3 #define N 100005 4 #define M 5000005 5 #define ll long long 6 using namespace std; 7 inline ll read(){ 8 #define num ch-'0' 9 char ch;bool flag=0;ll res; 10 while(!isdigit(ch=getchar())) 11 (ch=='-')&&(flag=true); 12 for(res=num;isdigit(ch=getchar());res=res*10+num); 13 (flag)&&(res=-res); 14 #undef num 15 return res; 16 } 17 int L[M],R[M],sum[M],rt[N]; 18 int val[N],pos[N],xx[N],yy[N],c[N],a1[N],a2[N]; 19 int n,cnt,q;ll ans=0; 20 inline int lowbit(int x){return x&(-x);} 21 int ask(int x){ 22 int s=0; 23 for(int i=x;i;i-=lowbit(i)) s+=c[i]; 24 return s; 25 } 26 void update(int &now,int l,int r,int k){ 27 if(!now) now=++cnt; 28 ++sum[now]; 29 if(l==r) return; 30 int mid=(l+r)>>1; 31 if(k<=mid) update(L[now],l,mid,k); 32 else update(R[now],mid+1,r,k); 33 } 34 int querysub(int x,int y,int v){ 35 int cntx=0,cnty=0,ans=0;--x; 36 for(int i=x;i;i-=lowbit(i)) xx[++cntx]=rt[i]; 37 for(int i=y;i;i-=lowbit(i)) yy[++cnty]=rt[i]; 38 int l=1,r=n; 39 while(l<r){ 40 int mid=(l+r)>>1; 41 if(v<=mid){ 42 for(int i=1;i<=cntx;++i) ans-=sum[R[xx[i]]]; 43 for(int i=1;i<=cnty;++i) ans+=sum[R[yy[i]]]; 44 for(int i=1;i<=cntx;++i) xx[i]=L[xx[i]]; 45 for(int i=1;i<=cnty;++i) yy[i]=L[yy[i]]; 46 r=mid; 47 } 48 else{ 49 for(int i=1;i<=cntx;++i) xx[i]=R[xx[i]]; 50 for(int i=1;i<=cnty;++i) yy[i]=R[yy[i]]; 51 l=mid+1; 52 } 53 } 54 return ans; 55 } 56 int querypre(int x,int y,int v){ 57 int cntx=0,cnty=0,ans=0;--x; 58 for(int i=x;i;i-=lowbit(i)) xx[++cntx]=rt[i]; 59 for(int i=y;i;i-=lowbit(i)) yy[++cnty]=rt[i]; 60 int l=1,r=n; 61 while(l<r){ 62 int mid=(l+r)>>1; 63 if(v>mid){ 64 for(int i=1;i<=cntx;++i) ans-=sum[L[xx[i]]]; 65 for(int i=1;i<=cnty;++i) ans+=sum[L[yy[i]]]; 66 for(int i=1;i<=cntx;++i) xx[i]=R[xx[i]]; 67 for(int i=1;i<=cnty;++i) yy[i]=R[yy[i]]; 68 l=mid+1; 69 } 70 else{ 71 for(int i=1;i<=cntx;++i) xx[i]=L[xx[i]]; 72 for(int i=1;i<=cnty;++i) yy[i]=L[yy[i]]; 73 r=mid; 74 } 75 } 76 return ans; 77 } 78 int main(){ 79 //freopen("testdata.in","r",stdin); 80 n=read(),q=read(); 81 for(int i=1;i<=n;++i){ 82 val[i]=read(),pos[val[i]]=i; 83 a1[i]=ask(n)-ask(val[i]); 84 ans+=a1[i]; 85 for(int j=val[i];j<=n;j+=lowbit(j)) ++c[j]; 86 } 87 memset(c,0,sizeof(c)); 88 for(int i=n;i;--i){ 89 a2[i]=ask(val[i]-1); 90 for(int j=val[i];j<=n;j+=lowbit(j)) ++c[j]; 91 } 92 while(q--){ 93 printf("%lld\n",ans); 94 int x=read();x=pos[x]; 95 ans-=(a1[x]+a2[x]-querysub(1,x-1,val[x])-querypre(x+1,n,val[x])); 96 for(int j=x;j<=n;j+=lowbit(j)) update(rt[j],1,n,val[x]); 97 } 98 return 0; 99 }
进阶
个人认为主席树的一些好题
【bzoj2653】【middle】
可以加深对主席树的应用,不再只会求第k大之类的套路

1 //minamoto 2 #include<iostream> 3 #include<cstdio> 4 #include<algorithm> 5 using namespace std; 6 #define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++) 7 char buf[1<<21],*p1=buf,*p2=buf; 8 inline int read(){ 9 #define num ch-'0' 10 char ch;bool flag=0;int res; 11 while(!isdigit(ch=getc())) 12 (ch=='-')&&(flag=true); 13 for(res=num;isdigit(ch=getc());res=res*10+num); 14 (flag)&&(res=-res); 15 #undef num 16 return res; 17 } 18 char obuf[1<<24],*o=obuf; 19 inline void print(int x){ 20 if(x>9) print(x/10); 21 *o++=x%10+48; 22 } 23 const int N=20005,M=N*30; 24 int n,Pre,q,cnt; 25 int rt[N],p[5]; 26 struct node{ 27 int l,r,lmx,rmx,sum; 28 }t[M],op; 29 struct data{ 30 int x,id; 31 inline bool operator <(const data &b)const 32 {return x<b.x;} 33 }a[N]; 34 inline void pushup(int x){ 35 t[x].sum=t[t[x].l].sum+t[t[x].r].sum; 36 t[x].lmx=max(t[t[x].l].lmx,t[t[x].l].sum+t[t[x].r].lmx); 37 t[x].rmx=max(t[t[x].r].rmx,t[t[x].r].sum+t[t[x].l].rmx); 38 } 39 void build(int &now,int l,int r){ 40 now=++cnt; 41 if(l==r){t[now].lmx=t[now].rmx=t[now].sum=1;return;} 42 int mid=(l+r)>>1; 43 build(t[now].l,l,mid); 44 build(t[now].r,mid+1,r); 45 pushup(now); 46 } 47 void update(int last,int &now,int l,int r,int k){ 48 now=++cnt; 49 if(l==r){t[now].lmx=t[now].rmx=t[now].sum=-1;return;} 50 int mid=(l+r)>>1; 51 if(k<=mid) t[now].r=t[last].r,update(t[last].l,t[now].l,l,mid,k); 52 else t[now].l=t[last].l,update(t[last].r,t[now].r,mid+1,r,k); 53 pushup(now); 54 } 55 node merge(node x,node y){ 56 node z; 57 z.sum=x.sum+y.sum; 58 z.lmx=max(x.lmx,x.sum+y.lmx); 59 z.rmx=max(y.rmx,y.sum+x.rmx); 60 return z; 61 } 62 node find(int x,int l,int r,int y,int z){ 63 if(y>z) return op; 64 if(l==y&&r==z) return t[x]; 65 int mid=(l+r)>>1; 66 if(z<=mid) return find(t[x].l,l,mid,y,z); 67 else if(y>mid) return find(t[x].r,mid+1,r,y,z); 68 else return merge(find(t[x].l,l,mid,y,mid),find(t[x].r,mid+1,r,mid+1,z)); 69 } 70 int query(int x){ 71 return find(rt[x],1,n,p[1],p[2]).rmx+find(rt[x],1,n,p[2]+1,p[3]-1).sum+find(rt[x],1,n,p[3],p[4]).lmx; 72 } 73 int main(){ 74 //freopen("testdata.in","r",stdin); 75 n=read(); 76 for(int i=1;i<=n;++i) a[i].x=read(),a[i].id=i; 77 sort(a+1,a+1+n); 78 build(rt[1],1,n); 79 for(int i=2;i<=n;++i) update(rt[i-1],rt[i],1,n,a[i-1].id); 80 q=read(); 81 while(q--){ 82 int x=read(),y=read(),z=read(),k=read(); 83 p[1]=(x+Pre)%n+1,p[2]=(y+Pre)%n+1,p[3]=(z+Pre)%n+1,p[4]=(k+Pre)%n+1; 84 sort(p+1,p+5); 85 int l=1,r=n,ans=1; 86 while(l<=r){ 87 int mid=(l+r)>>1; 88 int f=query(mid); 89 if(f>=0) ans=mid,l=mid+1; 90 else r=mid-1; 91 } 92 Pre=a[ans].x; 93 print(a[ans].x),*o++='\n'; 94 } 95 fwrite(obuf,o-obuf,1,stdout); 96 return 0; 97 }
hdu 4348 To the moon
主席树的区间修改,应该算是真正的可持久化?

1 //minamoto 2 #include<bits/stdc++.h> 3 #define ll long long 4 using namespace std; 5 const int N=100005,M=N*30; 6 int n,m,cnt,rt[N]; 7 int L[M],R[M];ll sum[M],add[M]; 8 #define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++) 9 char buf[1<<21],*p1=buf,*p2=buf; 10 inline ll read(){ 11 #define num ch-'0' 12 char ch;bool flag=0;ll res; 13 while(!isdigit(ch=getc())) 14 (ch=='-')&&(flag=true); 15 for(res=num;isdigit(ch=getc());res=res*10+num); 16 (flag)&&(res=-res); 17 #undef num 18 return res; 19 } 20 void build(int &now,int l,int r){ 21 add[now=++cnt]=0; 22 if(l==r) return (void)(sum[now]=read()); 23 int mid=(l+r)>>1; 24 build(L[now],l,mid); 25 build(R[now],mid+1,r); 26 sum[now]=sum[L[now]]+sum[R[now]]; 27 } 28 void update(int last,int &now,int l,int r,int ql,int qr,int x){ 29 now=++cnt; 30 L[now]=L[last],R[now]=R[last],add[now]=add[last],sum[now]=sum[last]; 31 sum[now]+=1ll*x*(qr-ql+1); 32 if(ql==l&&qr==r) return (void)(add[now]+=x); 33 int mid=(l+r)>>1; 34 if(qr<=mid) update(L[last],L[now],l,mid,ql,qr,x); 35 else if(ql>mid) update(R[last],R[now],mid+1,r,ql,qr,x); 36 else return (void)(update(L[last],L[now],l,mid,ql,mid,x),update(R[last],R[now],mid+1,r,mid+1,qr,x)); 37 } 38 ll query(int now,int l,int r,int ql,int qr){ 39 if(l==ql&&r==qr) return sum[now]; 40 int mid=(l+r)>>1; 41 ll res=1ll*add[now]*(qr-ql+1); 42 if(qr<=mid) res+=query(L[now],l,mid,ql,qr); 43 else if(ql>mid) res+=query(R[now],mid+1,r,ql,qr); 44 else res+=query(L[now],l,mid,ql,mid)+query(R[now],mid+1,r,mid+1,qr); 45 return res; 46 } 47 int main(){ 48 //freopen("testdata.in","r",stdin); 49 n=read(),m=read(); 50 cnt=-1; 51 build(rt[0],1,n); 52 int now=0; 53 while(m--){ 54 char ch;int l,r,x; 55 while(!isupper(ch=getc())); 56 switch(ch){ 57 case 'C':{ 58 l=read(),r=read(),x=read(); 59 ++now; 60 update(rt[now-1],rt[now],1,n,l,r,x); 61 break; 62 } 63 case 'Q':{ 64 l=read(),r=read(); 65 printf("%lld\n",query(rt[now],1,n,l,r)); 66 break; 67 } 68 case 'H':{ 69 l=read(),r=read(),x=read(); 70 printf("%lld\n",query(rt[x],1,n,l,r)); 71 break; 72 } 73 case 'B':{ 74 now=read(); 75 cnt=rt[now+1]-1; 76 break; 77 } 78 } 79 } 80 return 0; 81 }
鉴于本人十分弱鸡,可能讲的不是非常清楚,欢迎大家在下面补充