寫在前面
首先,在學樹鏈剖分之前最好先把 LCA、樹形DP、DFS序 這三個知識點學了
emm還有必備的 鏈式前向星、線段樹 也要先學了。
如果這三個知識點沒掌握好的話,樹鏈剖分難以理解也是當然的。
樹鏈剖分
樹鏈剖分 就是對一棵樹分成幾條鏈,把樹形變為線性,減少處理難度
需要處理的問題:
- 將樹從x到y結點最短路徑上所有節點的值都加上z
- 求樹從x到y結點最短路徑上所有節點的值之和
- 將以x為根節點的子樹內所有節點值都加上z
- 求以x為根節點的子樹內所有節點值之和
目錄:
- 重兒子:對於每一個非葉子節點,它的兒子中 以那個兒子為根的子樹節點數最大的兒子 為該節點的重兒子 (Ps: 感謝@shzr大佬指出我此句話的表達不嚴謹qwq, 已修改)
- 輕兒子:對於每一個非葉子節點,它的兒子中 非重兒子 的剩下所有兒子即為輕兒子
- 葉子節點沒有重兒子也沒有輕兒子(因為它沒有兒子。。)
- 重邊:一個父親連接他的重兒子的邊稱為重邊 //原寫法:連接任意兩個重兒子的邊叫做重邊
- 輕邊:剩下的即為輕邊
- 重鏈:相鄰重邊連起來的 連接一條重兒子 的鏈叫重鏈
- 對於葉子節點,若其為輕兒子,則有一條以自己為起點的長度為1的鏈
- 每一條重鏈以輕兒子為起點

這個dfs要處理幾件事情:
- 標記每個點的深度dep[]
- 標記每個點的父親fa[]
- 標記每個非葉子節點的子樹大小(含它自己)
- 標記每個非葉子節點的重兒子編號son[]
inline void dfs1(int x,int f,int deep){//x當前節點,f父親,deep深度
dep[x]=deep;//標記每個點的深度
fa[x]=f;//標記每個點的父親
siz[x]=1;//標記每個非葉子節點的子樹大小
int maxson=-1;//記錄重兒子的兒子數
for(Rint i=beg[x];i;i=nex[i]){
int y=to[i];
if(y==f)continue;//若為父親則continue
dfs1(y,x,deep+1);//dfs其兒子
siz[x]+=siz[y];//把它的兒子數加到它身上
if(siz[y]>maxson)son[x]=y,maxson=siz[y];//標記每個非葉子節點的重兒子編號
}
}//變量解釋見最下面
這個dfs2也要預處理幾件事情
- 標記每個點的新編號
- 賦值每個點的初始值到新編號上
- 處理每個點所在鏈的頂端
- 處理每條鏈
順序:先處理重兒子再處理輕兒子,理由后面說
inline void dfs2(int x,int topf){//x當前節點,topf當前鏈的最頂端的節點
id[x]=++cnt;//標記每個點的新編號
wt[cnt]=w[x];//把每個點的初始值賦到新編號上來
top[x]=topf;//這個點所在鏈的頂端
if(!son[x])return;//如果沒有兒子則返回
dfs2(son[x],topf);//按先處理重兒子,再處理輕兒子的順序遞歸處理
for(Rint i=beg[x];i;i=nex[i]){
int y=to[i];
if(y==fa[x]||y==son[x])continue;
dfs2(y,y);//對於每一個輕兒子都有一條從它自己開始的鏈
}
}//變量解釋見最下面
Attention 重要的來了!!!
前面說到dfs2的順序是先處理重兒子再處理輕兒子
我們來模擬一下:

- 因為順序是先重再輕,所以每一條重鏈的新編號是連續的
- 因為是dfs,所以每一個子樹的新編號也是連續的
現在回顧一下我們要處理的問題
- 處理任意兩點間路徑上的點權和
- 處理一點及其子樹的點權和
- 修改任意兩點間路徑上的點權
- 修改一點及其子樹的點權
1、當我們要處理任意兩點間路徑時:
設所在鏈頂端的深度更深的那個點為x點
- ans加上x點到x所在鏈頂端 這一段區間的點權和
- 把x跳到x所在鏈頂端的那個點的上面一個點
不停執行這兩個步驟,直到兩個點處於一條鏈上,這時再加上此時兩個點的區間和即可

這時我們注意到,我們所要處理的所有區間均為連續編號(新編號),於是想到線段樹,用線段樹處理連續編號區間和
每次查詢時間復雜度為\(O( \log^2n)\)
inline int qRange(int x,int y){
int ans=0;
while(top[x]!=top[y]){//當兩個點不在同一條鏈上
if(dep[top[x]]<dep[top[y]])swap(x,y);//把x點改為所在鏈頂端的深度更深的那個點
res=0;
query(1,1,n,id[top[x]],id[x]);//ans加上x點到x所在鏈頂端 這一段區間的點權和
ans+=res;
ans%=mod;//按題意取模
x=fa[top[x]];//把x跳到x所在鏈頂端的那個點的上面一個點
}
//直到兩個點處於一條鏈上
if(dep[x]>dep[y])swap(x,y);//把x點深度更深的那個點
res=0;
query(1,1,n,id[x],id[y]);//這時再加上此時兩個點的區間和即可
ans+=res;
return ans%mod;
}//變量解釋見最下面
2、處理一點及其子樹的點權和:
想到記錄了每個非葉子節點的子樹大小(含它自己),並且每個子樹的新編號都是連續的
於是直接線段樹區間查詢即可
時間復雜度為\(O( \log n)\)
inline int qSon(int x){
res=0;
query(1,1,n,id[x],id[x]+siz[x]-1);//子樹區間右端點為id[x]+siz[x]-1
return res;
}
當然,區間修改就和區間查詢一樣的啦~~
inline void updRange(int x,int y,int k){
k%=mod;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(1,1,n,id[top[x]],id[x],k);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
update(1,1,n,id[x],id[y],k);
}
inline void updSon(int x,int k){
update(1,1,n,id[x],id[x]+siz[x]-1,k);
}//變量解釋見最下面
既然前面說到要用線段樹,那么按題意建樹就可以啦!
不過,建樹這一步當然是在處理問題之前哦~
AC代碼:
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#define Rint register int
#define mem(a,b) memset(a,(b),sizeof(a))
#define Temp template<typename T>
using namespace std;
typedef long long LL;
Temp inline void read(T &x){
x=0;T w=1,ch=getchar();
while(!isdigit(ch)&&ch!='-')ch=getchar();
if(ch=='-')w=-1,ch=getchar();
while(isdigit(ch))x=(x<<3)+(x<<1)+(ch^'0'),ch=getchar();
x=x*w;
}
#define mid ((l+r)>>1)
#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r
#define len (r-l+1)
const int maxn=200000+10;
int n,m,r,mod;
//見題意
int e,beg[maxn],nex[maxn],to[maxn],w[maxn],wt[maxn];
//鏈式前向星數組,w[]、wt[]初始點權數組
int a[maxn<<2],laz[maxn<<2];
//線段樹數組、lazy操作
int son[maxn],id[maxn],fa[maxn],cnt,dep[maxn],siz[maxn],top[maxn];
//son[]重兒子編號,id[]新編號,fa[]父親節點,cnt dfs_clock/dfs序,dep[]深度,siz[]子樹大小,top[]當前鏈頂端節點
int res=0;
//查詢答案
inline void add(int x,int y){//鏈式前向星加邊
to[++e]=y;
nex[e]=beg[x];
beg[x]=e;
}
//-------------------------------------- 以下為線段樹
inline void pushdown(int rt,int lenn){
laz[rt<<1]+=laz[rt];
laz[rt<<1|1]+=laz[rt];
a[rt<<1]+=laz[rt]*(lenn-(lenn>>1));
a[rt<<1|1]+=laz[rt]*(lenn>>1);
a[rt<<1]%=mod;
a[rt<<1|1]%=mod;
laz[rt]=0;
}
inline void build(int rt,int l,int r){
if(l==r){
a[rt]=wt[l];
if(a[rt]>mod)a[rt]%=mod;
return;
}
build(lson);
build(rson);
a[rt]=(a[rt<<1]+a[rt<<1|1])%mod;
}
inline void query(int rt,int l,int r,int L,int R){
if(L<=l&&r<=R){res+=a[rt];res%=mod;return;}
else{
if(laz[rt])pushdown(rt,len);
if(L<=mid)query(lson,L,R);
if(R>mid)query(rson,L,R);
}
}
inline void update(int rt,int l,int r,int L,int R,int k){
if(L<=l&&r<=R){
laz[rt]+=k;
a[rt]+=k*len;
}
else{
if(laz[rt])pushdown(rt,len);
if(L<=mid)update(lson,L,R,k);
if(R>mid)update(rson,L,R,k);
a[rt]=(a[rt<<1]+a[rt<<1|1])%mod;
}
}
//---------------------------------以上為線段樹
inline int qRange(int x,int y){
int ans=0;
while(top[x]!=top[y]){//當兩個點不在同一條鏈上
if(dep[top[x]]<dep[top[y]])swap(x,y);//把x點改為所在鏈頂端的深度更深的那個點
res=0;
query(1,1,n,id[top[x]],id[x]);//ans加上x點到x所在鏈頂端 這一段區間的點權和
ans+=res;
ans%=mod;//按題意取模
x=fa[top[x]];//把x跳到x所在鏈頂端的那個點的上面一個點
}
//直到兩個點處於一條鏈上
if(dep[x]>dep[y])swap(x,y);//把x點深度更深的那個點
res=0;
query(1,1,n,id[x],id[y]);//這時再加上此時兩個點的區間和即可
ans+=res;
return ans%mod;
}
inline void updRange(int x,int y,int k){//同上
k%=mod;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(1,1,n,id[top[x]],id[x],k);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
update(1,1,n,id[x],id[y],k);
}
inline int qSon(int x){
res=0;
query(1,1,n,id[x],id[x]+siz[x]-1);//子樹區間右端點為id[x]+siz[x]-1
return res;
}
inline void updSon(int x,int k){//同上
update(1,1,n,id[x],id[x]+siz[x]-1,k);
}
inline void dfs1(int x,int f,int deep){//x當前節點,f父親,deep深度
dep[x]=deep;//標記每個點的深度
fa[x]=f;//標記每個點的父親
siz[x]=1;//標記每個非葉子節點的子樹大小
int maxson=-1;//記錄重兒子的兒子數
for(Rint i=beg[x];i;i=nex[i]){
int y=to[i];
if(y==f)continue;//若為父親則continue
dfs1(y,x,deep+1);//dfs其兒子
siz[x]+=siz[y];//把它的兒子數加到它身上
if(siz[y]>maxson)son[x]=y,maxson=siz[y];//標記每個非葉子節點的重兒子編號
}
}
inline void dfs2(int x,int topf){//x當前節點,topf當前鏈的最頂端的節點
id[x]=++cnt;//標記每個點的新編號
wt[cnt]=w[x];//把每個點的初始值賦到新編號上來
top[x]=topf;//這個點所在鏈的頂端
if(!son[x])return;//如果沒有兒子則返回
dfs2(son[x],topf);//按先處理重兒子,再處理輕兒子的順序遞歸處理
for(Rint i=beg[x];i;i=nex[i]){
int y=to[i];
if(y==fa[x]||y==son[x])continue;
dfs2(y,y);//對於每一個輕兒子都有一條從它自己開始的鏈
}
}
int main(){
read(n);read(m);read(r);read(mod);
for(Rint i=1;i<=n;i++)read(w[i]);
for(Rint i=1;i<n;i++){
int a,b;
read(a);read(b);
add(a,b);add(b,a);
}
dfs1(r,0,1);
dfs2(r,r);
build(1,1,n);
while(m--){
int k,x,y,z;
read(k);
if(k==1){
read(x);read(y);read(z);
updRange(x,y,z);
}
else if(k==2){
read(x);read(y);
printf("%d\n",qRange(x,y));
}
else if(k==3){
read(x);read(y);
updSon(x,y);
}
else{
read(x);
printf("%d\n",qSon(x));
}
}
}
