前言
本博客用於總結聯賽中常考的數據結構和樹論,大概會寫一點樹鏈剖分,\(dsu on tree\),樹狀數組,線段樹,平衡樹,dfs序,樹上差分等等。
雖然對於聯賽來說,數據結構的意義更多是騙分,但畢竟\(CSP\)不同於\(NOIP\),萬一就想標新立異呢?
也許會附帶一些簡要的講解,聯賽后有時間會寫詳細的講解,但是我真的很懶,所以請不要有過大的期望。
\(\texttt{Talk is cheap.Let me show you the code.}\)
樹狀數組
樹狀數組是一種支持區間查詢,區間更新,單點查詢,單點更新,區間最值,逆序對,區間不同的個數等多種操作的數據結構,復雜度為\(O(nlogn)\),優點是代碼簡介,常數非常小,但是不是很好理解。(說實話我覺得樹狀數組是最好理解的)樹狀數組的實現和位運算密切相關,也就是\(lowbit\)運算,樹狀數組最核心的思想是前綴和。
區間查詢與單點更新
#include<cstdio>
int n,m,c[500005];
int lowbit(int x)
{
return x&-x;
}
int query(int pos)
{
int ans=0;
for(int i=pos;i>=1;i-=lowbit(i))
ans+=c[i];
return ans;
}
void add(int pos,int x)
{
for(int i=pos;i<=n;i+=lowbit(i))
c[i]+=x;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
int x;
scanf("%d",&x);
add(i,x);
}
while(m--)
{
int opt,x,y;
scanf("%d%d%d",&opt,&x,&y);
if(opt==1)add(x,y);
else printf("%d\n",query(y)-query(x-1));
}
return 0;
}
區間更新與單點查詢
這個也很簡單,但是我們不能用普通的樹狀數組來做,初始化的時候並不是在記錄初始數組而是差分數組,這樣就可以用樹狀數組1中區間查詢的套路來查詢單點,這是差分的性質,而至於區間更新\([l,r]\),只需要在\(r+1\)的地方加上這個值,在\(l\)的地方減去這個值,根據差分的性質,最后就能求出正確答案,重點就是熟悉差分。
有一個小細節,區間更新不能直接像樹狀數組\(1\)里面寫\(l-1\)和\(r\),應該是\(r+1\)和\(l\),至於為什么,希望讀者自己去思考,深入理解樹狀數組的實現。
#include<cstdio>
int n,m,c[500005],a[500005];
int lowbit(int x){return x&-x;}
void add(int pos,int x)
{
for(int i=pos;i<=n;i+=lowbit(i))
c[i]+=x;
}
int query(int pos)
{
int ans=0;
for(int i=pos;i>=1;i-=lowbit(i))
ans+=c[i];
return ans;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]),add(i,a[i]-a[i-1]);
while(m--)
{
int opt,x,y,k;
scanf("%d%d",&opt,&x);
if(opt==1)
{
scanf("%d%d",&y,&k);
add(x,k),add(y+1,-k);
}
else printf("%d\n",query(x));
}
return 0;
}
求逆序對
想不到吧,還能整這個!
這個思想非常簡單,不講了,只要懂樹狀數組就能理解。核心就是倒敘\(n~1\)循環,每次給\(ans\)累計\(query(a[i]-1)\)的答案,再在\(a[i]\)的位置加\(1\),手動模擬一下就可以了,注意一下離散化的排序問題。
#include<bits/stdc++.h>
using namespace std;
#define int long long
int n,a[1000005],c[1000005],ans;
int lowbit(int x){return x&-x;}
int query(int pos){int ans=0;for(int i=pos;i>=1;i-=lowbit(i))ans+=c[i];return ans;}
void add(int pos,int x){for(int i=pos;i<=n;i+=lowbit(i))c[i]+=x;}
struct node{int tmp,num;}b[1000005];
bool cmp(node a,node b){return a.tmp<b.tmp||(a.tmp==b.tmp&&a.num<b.num);}
signed main()
{
ios::sync_with_stdio(false);
cin>>n;
for(int i=1;i<=n;i++){cin>>b[i].tmp;b[i].num=i;}
sort(b+1,b+1+n,cmp);
for(int i=1;i<=n;i++)a[b[i].num]=i;
for(int i=n;i>=1;i--)
ans+=query(a[i]-1),add(a[i],1);
cout<<ans<<endl;
return 0;
}
區間不同值
就是求某個區間內一共有多少個不相同的元素。
如果要用樹狀數組來求這個的話,限制非常多,因為必須使用離線操作,這也就意味着我們無法進行更新操作。
我們考慮一個序列:\(\texttt{1 2 3 4 3 5}\),會發現,如果我們要查詢一個區間\([l,r]\),比如\(l=3,r=6\),此時區間內有\(2\)個元素為\(3\),但實際上影響我們最終答案的只與后面的那個\(3\)有關。
我們考慮正在查詢一個區間\([l,r]\),從\(1\)循環到\(r\),每遇到一個\(a[i]\),就\(add(i,1)\),而如果\(a[i]\)這個數在前面的位置\(pre\)出現過,就\(add(pre,-1)\),再更新\(a[j]\)出現的位置為\(j\),最后直接利用前綴和來查詢就可以了。
但你還要考慮一個問題,如果在前面出現的\(r\)比在后面出現的\(r\)要小,那么就會出現錯誤情況,所以你要把所有的\(r\)從小到大排序。
我覺得講樹狀數組很難,因為本來就不適合講,適合自己畫圖、模擬去理解。
#include<iostream>
#include<algorithm>
using namespace std;
int n,a[1000005],vis[1000005],p[1000005],c[1000005],q;
struct node{int l,r,ask;}b[1000005];
int lowbit(int x){return x&-x;}
void add(int pos,int x){for(int i=pos;i<=n;i+=lowbit(i))c[i]+=x;}
int query(int pos)
{
int ans=0;
for(int i=pos;i>=1;i-=lowbit(i))ans+=c[i];
return ans;
}
bool cmp(node a,node b){return a.r<b.r;}
int main()
{
ios::sync_with_stdio(false);
cin>>n;
for(int i=1;i<=n;i++)cin>>a[i];
cin>>q;
for(int i=1;i<=q;i++){cin>>b[i].l>>b[i].r;b[i].ask=i;}
sort(b+1,b+1+q,cmp);
int sta=1;
for(int i=1;i<=q;i++)
{
for(int j=sta;j<=b[i].r;j++)
{
add(j,1);
if(vis[a[j]])add(vis[a[j]],-1);
vis[a[j]]=j;
}
p[b[i].ask]=query(b[i].r)-query(b[i].l-1);
sta=b[i].r+1;
}
for(int i=1;i<=q;i++)cout<<p[i]<<"\n";
return 0;
}
樹狀數組講到這里就可以了,因為確實功能比較少並且理解復雜。
線段樹
線段樹是一種非常優秀的數據結構,復雜度\(O(nlogn)\),雖然常數比較大,代碼也比較長,只要理解了,還是好寫,功能相對非常多,樹狀數組能做的事情它都能做,除此之外線段樹還支持好多好多操作以及優化,慢慢來吧。
區間/單點的加法/查詢
線段樹支持這四種操作同時進行,雖然樹狀數組也支持,但寫起來很麻煩,其實線段樹理解了,也是很好寫的。
當然了,細節很重要,比如\(pushdown\)和\(pushup\)的運用等等,而且也要注意常數的問題,結構體的常數一般會大一點,\(\texttt{zkw}\)線段樹的常數很小,但我沒有學,所以我一般習慣寫不帶結構體的線段樹。
#include<bits/stdc++.h>
using namespace std;
#define int long long
int n,m,a[100005],t[400005],laz[400005];
void pushdown(int rt,int l,int r)
{
if(!laz[rt])return;
int len=(r-l+1);
laz[rt<<1]+=laz[rt];
laz[rt<<1|1]+=laz[rt];
t[rt<<1]+=(len-(len>>1))*laz[rt];
t[rt<<1|1]+=(len>>1)*laz[rt];
laz[rt]=0;
}
void pushup(int rt){t[rt]=(t[rt<<1]+t[rt<<1|1]);}
void build(int l,int r,int rt)
{
if(l==r){t[rt]=a[l];return;}
int mid=(l+r)>>1;
build(l,mid,rt<<1);
build(mid+1,r,rt<<1|1);
pushup(rt);
}
void update(int rt,int l,int r,int la,int ra,int x)
{
if(la<=l&&r<=ra)
{
t[rt]+=(r-l+1)*x;
laz[rt]+=x;
return;
}
pushdown(rt,l,r);
int mid=(l+r)>>1;
if(la<=mid)update(rt<<1,l,mid,la,ra,x);
if(ra>mid)update(rt<<1|1,mid+1,r,la,ra,x);
pushup(rt);
}
int query(int rt,int l,int r,int la,int ra)
{
if(la<=l&&r<=ra)return t[rt];
pushdown(rt,l,r);
int mid=(l+r)>>1;
int ans=0;
if(la<=mid)ans+=query(rt<<1,l,mid,la,ra);
if(ra>mid)ans+=query(rt<<1|1,mid+1,r,la,ra);
return ans;
}
signed main()
{
ios::sync_with_stdio(false);
cin>>n>>m;
for(int i=1;i<=n;i++)cin>>a[i];
build(1,n,1);
while(m--)
{
int opt,x,y,k;
cin>>opt>>x>>y;
if(opt&1)cin>>k,update(1,1,n,x,y,k);
else cout<<query(1,1,n,x,y)<<"\n";
}
return 0;
}
區間/單點的乘法/加法/更新
其實相對線段樹\(1\),只是多了一個乘法的操作而已,多打一個\(lazy\)標記就可以了。
說難也難,說簡單也簡單,除了乘法加法的優先度問題之外,其他沒有什么不同。
反正就是注意細節啦,因為在我區間乘法的時候忘記了更新加法的\(lazytag\)還\(WA\)了一次,所以第一要深入理解,第二要注意細節。
#include<bits/stdc++.h>
using namespace std;
#define int long long
int n,m,mod,t[400005],lazp[400005],lazm[400005],a[400005];
void pushup(int rt){t[rt]=(t[rt<<1]+t[rt<<1|1])%mod;}
void pushdown(int rt,int l,int r)
{
int len=r-l+1;
t[rt<<1]=(t[rt<<1]*lazm[rt]+lazp[rt]*(len-(len>>1)))%mod;
t[rt<<1|1]=(t[rt<<1|1]*lazm[rt]+lazp[rt]*(len>>1))%mod;
lazm[rt<<1]=lazm[rt]*lazm[rt<<1]%mod;
lazm[rt<<1|1]=lazm[rt<<1|1]*lazm[rt]%mod;
lazp[rt<<1]=(lazp[rt<<1]*lazm[rt]+lazp[rt])%mod;
lazp[rt<<1|1]=(lazp[rt<<1|1]*lazm[rt]+lazp[rt])%mod;
lazp[rt]=0,lazm[rt]=1;
}
void build(int rt,int l,int r)
{
lazm[rt]=1;
if(l==r){t[rt]=a[l]%mod;return;}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void update_plus(int rt,int l,int r,int la,int ra,int x)
{
if(la<=l&&r<=ra)
{
t[rt]=(t[rt]+(r-l+1)*x)%mod;
lazp[rt]=(lazp[rt]+x)%mod;
return;
}
pushdown(rt,l,r);
int mid=(l+r)>>1;
if(la<=mid)update_plus(rt<<1,l,mid,la,ra,x);
if(ra>mid)update_plus(rt<<1|1,mid+1,r,la,ra,x);
pushup(rt);
}
void update_mul(int rt,int l,int r,int la,int ra,int x)
{
if(la<=l&&r<=ra)
{
t[rt]=(t[rt]*x%mod);
lazm[rt]=lazm[rt]*x%mod;
lazp[rt]=lazp[rt]*x%mod;//記得更新加法的tag
return;
}
pushdown(rt,l,r);
int mid=(l+r)>>1;
if(la<=mid)update_mul(rt<<1,l,mid,la,ra,x);
if(ra>mid)update_mul(rt<<1|1,mid+1,r,la,ra,x);
pushup(rt);
}
int query(int rt,int l,int r,int la,int ra)
{
if(la<=l&&r<=ra)return t[rt]%mod;
pushdown(rt,l,r);
int mid=(l+r)>>1;
int ans=0;
if(la<=mid)ans=(ans+query(rt<<1,l,mid,la,ra))%mod;
if(ra>mid)ans=(ans+query(rt<<1|1,mid+1,r,la,ra))%mod;
return ans;
}
signed main()
{
ios::sync_with_stdio(false);
cin>>n>>m>>mod;
for(int i=1;i<=n;i++)cin>>a[i];
build(1,1,n);
while(m--)
{
int opt,x,y,k;
cin>>opt>>x>>y;
if(opt==1)cin>>k,update_mul(1,1,n,x,y,k);
else if(opt==2)cin>>k,update_plus(1,1,n,x,y,k);
else cout<<query(1,1,n,x,y)<<"\n";
}
return 0;
}
區間最值
求區間最值的時候常常會伴隨着區間更新,那么我放一道裸題。
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
int n,t[800005],a[800005],m;
void pushup(int rt){t[rt]=max(t[rt<<1],t[rt<<1|1]);}
void build(int rt,int l,int r)
{
if(l==r){t[rt]=a[l];return;}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void update(int rt,int l,int r,int pos,int x)
{
if(l==r){t[rt]=x;return;}
int mid=(l+r)>>1;
if(pos<=mid)update(rt<<1,l,mid,pos,x);
else update(rt<<1|1,mid+1,r,pos,x);
pushup(rt);
}
int query(int rt,int l,int r,int la,int ra)
{
if(la<=l&&r<=ra)return t[rt];
int mid=(l+r)>>1;
int ans=0;
if(la<=mid)ans=max(ans,query(rt<<1,l,mid,la,ra));
if(ra>mid)ans=max(ans,query(rt<<1|1,mid+1,r,la,ra));
return ans;
}
int main()
{
// ios::sync_with_stdio(false);
while(scanf("%d%d",&n,&m)!=EOF)
{
for(int i=1;i<=n;i++)scanf("%d",&a[i]);
build(1,1,n);
while(m--)
{
char op[2];
int x,y;
scanf("%s%d%d",op,&x,&y);
if(op[0]=='U')update(1,1,n,x,y);
else printf("%d\n",query(1,1,n,x,y));
}
}
return 0;
}
樹鏈剖分
樹鏈剖分是一類解決樹上詢問的算法。
具體來說,我們需要引入幾個概念。
重兒子:每一個點的子孫中,子樹最大的兒子。
輕兒子:除了重兒子之外的所有子孫。
重邊:每個點和重兒子相連的那一條邊。
輕邊:除了重邊之外的所有邊。
重邊相連成重鏈,輕鏈最多只有一條輕邊組成。
根據這張圖可以看出來,加粗的是重邊,標紅的是輕兒子。
根據這樣,我們就把一棵樹剖成了一條一條的鏈,這個鏈的順序可以使用\(dfs序\)來維護,樹轉線性過后,要維護這棵樹上的各種信息,就可以使用數據結構了,我用得比較多的是線段樹。
放一道題 模板 樹鏈剖分
#include<bits/stdc++.h>
#define p mod
using namespace std;
int n,m,q,mod,rt,t[800005],d[200005],laz[800005],siz[100005],son[100005],top[100005],id[100005],idx,cnt,dfn[100005],fa[100005],h[200005],w[100005],wn[100005];
struct node{int v,nxt;}e[200005];
void add(int u,int v)
{
e[++cnt].v=v;
e[cnt].nxt=h[u];
h[u]=cnt;
}
void dfs1(int u,int f)
{
d[u]=d[f]+1;
siz[u]=1;
fa[u]=f;
for(int i=h[u];i!=-1;i=e[i].nxt)
{
int v=e[i].v;
if(v==f)continue;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int f)
{
id[u]=++idx;
// dfn[idx]=u;
wn[idx]=w[u];
top[u]=f;
if(!son[u])return;
dfs2(son[u],f);
for(int i=h[u];i!=-1;i=e[i].nxt)
{
int v=e[i].v;
if(v==fa[u]||v==son[u])continue;//此處的v應該和fa[u]比較!!!!!!!!
dfs2(v,v);
}
}
void pushup(int rt){t[rt]=(t[rt<<1]+t[rt<<1|1])%mod;}
void pushdown(int rt,int l,int r)
{
if(!laz[rt])return;
int len=r-l+1;
laz[rt<<1]=(laz[rt<<1]+laz[rt])%mod;
laz[rt<<1|1]=(laz[rt<<1|1]+laz[rt])%mod;
t[rt<<1]=(t[rt<<1]+(len-(len>>1))*laz[rt])%mod;
t[rt<<1|1]=(t[rt<<1|1]+(len>>1)*laz[rt])%mod;
// t[rt<<1]%=mod,t[rt<<1|1]%=mod;
laz[rt]=0;
}
void build(int rt,int l,int r)
{
if(l==r){t[rt]=wn[l]%mod;return;}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void update(int rt,int l,int r,int la,int ra,int x)
{
if(la<=l&&r<=ra)
{
t[rt]=(t[rt]+(r-l+1)*x)%mod;
laz[rt]=(laz[rt]+x)%mod;
return;
}
pushdown(rt,l,r);
int mid=(l+r)>>1;
if(la<=mid)update(rt<<1,l,mid,la,ra,x);
if(ra>mid)update(rt<<1|1,mid+1,r,la,ra,x);
pushup(rt);
}
int query(int rt,int l,int r,int la,int ra)
{
if(la<=l&&r<=ra)return t[rt]%mod;
pushdown(rt,l,r);
int ans=0;
int mid=(l+r)>>1;
if(la<=mid)ans=(ans+query(rt<<1,l,mid,la,ra))%mod;
if(ra>mid)ans=(ans+query(rt<<1|1,mid+1,r,la,ra))%mod;
return ans%mod;
}
void point_upt(int u,int v,int x)
{
while(top[u]!=top[v])
{
if(d[top[u]]<d[top[v]])swap(u,v);
update(1,1,n,id[top[u]],id[u],x);
u=fa[top[u]];
}
if(d[u]>d[v])swap(u,v);
update(1,1,n,id[u],id[v],x);
}
int point_query(int u,int v)
{
int ans=0;
while(top[u]!=top[v])
{
if(d[top[u]]<d[top[v]])swap(u,v);
ans=(ans+query(1,1,n,id[top[u]],id[u]))%mod;
u=fa[top[u]];
}
if(d[u]>d[v])swap(u,v);
ans=(ans+query(1,1,n,id[u],id[v]))%mod;
return ans%mod;
}
void tree_upt(int rt,int x){update(1,1,n,id[rt],id[rt]+siz[rt]-1,x);}
int tree_query(int rt){return query(1,1,n,id[rt],id[rt]+siz[rt]-1)%mod;}
inline void read(int&x)
{
x=0;char c=getchar();
while(c<'0'||c>'9')c=getchar();
while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar();
}
int main()
{
memset(h,-1,sizeof h);
read(n),read(m),read(rt),read(mod);
for(int i=1;i<=n;i++)read(w[i]);
for(int i=1,u,v;i<n;i++)read(u),read(v),add(u,v),add(v,u);
dfs1(rt,0),dfs2(rt,rt),build(1,1,n);
while(m--)
{
int opt,x,y,z;
read(opt);
if(opt==1)read(x),read(y),read(z),point_upt(x,y,z);
if(opt==2)read(x),read(y),printf("%d\n",point_query(x,y));
if(opt==3)read(x),read(y),tree_upt(x,y);
if(opt==4)read(x),printf("%d\n",tree_query(x));
}
return 0;
}
Dsu on tree
\(\texttt{Dsu on tree}\),樹上啟發式合並,也就是\(lxl\)口中的靜態鏈分治。
建議去看看窩的學長的\(Blog\) pzy神仙!(破音)
考慮這樣一類樹上問題:
-
無修改操作,允許詢問離線。
-
對子樹信息進行統計。
你看到這道題的時候是不是很懵?我也是。這道題是支持離線詢問的,可以跑樹上莫隊,也可以跑樹狀數組(參照\(HH\)的項鏈),但作為\(dsu\)的模板題,還是要負責地講一講\(dsu\)。
具體做法:
-
定義一個全局的貢獻統計\(cnt[i]\) 下標\(i\)表示這種顏色出現了多少次
-
利用樹剖性質
-
遍歷輕邊並記錄輕兒子的貢獻(與此同時要記錄下輕兒子點的答案) 再清除輕兒子的貢獻
-
遍歷重兒子 記錄並且保留貢獻
-
再次暴力統計輕兒子的貢獻
那么就產生了一個困擾我非常久的問題,前后問了\(SimonSu\),\(koalawy\),\(pzy\)神仙,最后\(pzy\)無可奈何地給我手\(\%\)了一下,終於聽懂了...
為什么我們不直接統計輕兒子再統計重兒子/先統計重兒子再統計輕兒子呢?
我們如果直接統計,那么輕/重兒子的貢獻被保存在了\(cnt\)中,由於\(cnt\)的性質是全局的,所以會對我們接下來進行重/輕兒子的統計出現影響。
而為什么我們要保留重兒子而不是輕兒子呢?因為重兒子很多,輕兒子已經被證明了是不多於\(O(logn)\)條的,所以我們選擇暴力統計輕兒子而不是重兒子,最終的復雜度可以證明是\(O(nlogn)\)。
\(\texttt{ p z y t x d y !!!!!!!!!!!!!!!!!}\)
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
int n,m,c[200005],h[200005],cnt[200005],ans[200005],top,siz[200005],son[200005],vis[200005],num;
struct node{int v,nxt;}e[200005];
void add(int u,int v)
{
e[++num].v=v;
e[num].nxt=h[u];
h[u]=num;
}
void dfs(int u,int fa)
{
siz[u]=1;
for(int i=h[u];i!=-1;i=e[i].nxt)
{
int v=e[i].v;
if(v==fa)continue;
dfs(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void calc(int u,int fa,int val)
{
if(val>0)
{
if(!cnt[c[u]])top++;
cnt[c[u]]++;
}
else
{
if(cnt[c[u]]<=1)top--;
cnt[c[u]]--;
}
for(int i=h[u];i!=-1;i=e[i].nxt)
{
int v=e[i].v;
if(v==fa||vis[v])continue;
calc(v,u,val);
}
}
void dsu(int u,int fa,int val)
{
for(int i=h[u];i!=-1;i=e[i].nxt)
{
int v=e[i].v;
if(v==fa||v==son[u])continue;
dsu(v,u,0);
}
if(son[u])dsu(son[u],u,1),vis[son[u]]=1;//當我們遍歷了全部的輕兒子 如果有一個重兒子沒有被遍歷 就遍歷qwq
calc(u,fa,1),vis[son[u]]=0; //累計答案
ans[u]=top;//答案下傳
if(!val)calc(u,fa,-1);//如果當前節點是輕兒子 那么我們需要減去答案
}
inline void read(int&x)
{
x=0;char c=getchar();
while(c<'0'||c>'9')c=getchar();
while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar();
}
int main()
{
memset(h,-1,sizeof h);
read(n);
for(int i=1,u,v;i<n;i++)read(u),read(v),add(u,v),add(v,u);
for(int i=1;i<=n;i++)read(c[i]);
dfs(1,0),dsu(1,0,1);
read(m);
for(int i=1,ask;i<=m;i++)read(ask),printf("%d\n",ans[ask]);
return 0;
}