dsu on tree
簡介
我也不清楚dsu是什么的英文縮寫...
好吧是Disjoint Set Union 並查集2333
就像是樹上的啟發式合並
用到了\(heavy-light\ decomposition\)樹鏈剖分
把輕邊子樹的信息合並到重鏈上的點里
因為每次都是先dfs輕兒子再dfs重兒子,只有重兒子子樹的貢獻保留,所以可以保證dfs到每顆子樹時當前全局維護的信息不會有別的子樹里的,和莫隊很像
算法過程
find the BigChild of each vertex
dfs(u, fa, keep)
dfs(LightChild, u, 0)
dfs(BigChild, u, 1), big[BigChild] = 1
update(u, fa, 1) //calculate the contribution of u's LightChild's SubTree
update the ans of u
big[BigChild] = 0
if keep == 0
update(u, fa, -1) //remove the contributino of u's SubTree
update(u, fa, val)
calculate u's information
update(v : (u, v) and !big[v], u, val)
先遞歸計算輕兒子子樹,遞歸結束時消除他們的貢獻
再遞歸計算重兒子子樹,保留他的貢獻
再計算當前子樹中所有輕子樹的貢獻
更新答案
如果當前子樹是父節點的輕子樹,消除當前子樹的貢獻
復雜度分析
顯然只有遇到輕邊才會把自己的信息合並到父節點
樹鏈剖分后每個點到根的路徑上有\(logn\)條輕邊和\(lgon\)條重鏈
一個點的信息只會向上合並\(logn\)次
如果一個點的信息修改是\(O(1)\)的,那么總復雜度就是\(O(nlogn)\)
從這里我們可以發現和對dfs序使用莫隊有異曲同工之妙,莫隊也要求修改的復雜度很低
應用
- 優秀的dfs序莫隊替代品,復雜度\(\sqrt{n} \rightarrow logn\)
- 結合點分治的思想可以做一些有根樹上的路徑統計問題
模板題
CF600E. Lomsat gelral
題意:詢問每顆子樹中出現次數最多的顏色們編號和
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
#define pii pair<int, ll>
#define MP make_pair
#define fir first
#define sec second
const int N=1e5+5;
int read(){
char c=getchar();int x=0,f=1;
while(c<'0'||c>'9'){if(c=='-')f=-1; c=getchar();}
while(c>='0'&&c<='9'){x=x*10+c-'0'; c=getchar();}
return x*f;
}
int n, a[N];
struct edge{int v, ne;}e[N<<1];
int cnt, h[N];
inline void ins(int u, int v) {
e[++cnt]=(edge){v, h[u]}; h[u]=cnt;
e[++cnt]=(edge){u, h[v]}; h[v]=cnt;
}
int size[N], mx[N], big[N];
void dfs(int u, int fa) {
size[u]=1;
for(int i=h[u];i;i=e[i].ne)
if(e[i].v != fa) {
dfs(e[i].v, u);
size[u] += size[e[i].v];
if(size[e[i].v] > size[mx[u]]) mx[u] = e[i].v;
}
}
int cou[N], Max; ll ans[N];
pii f[N];
void update(int u, int fa, int val) {
int &c = cou[a[u]];
f[c].fir --; f[c].sec -= a[u];
c += val;
f[c].fir ++; f[c].sec += a[u];
if(val==1) Max = max(Max, c);
else if(!f[Max].fir) Max--;
for(int i=h[u];i;i=e[i].ne)
if(e[i].v != fa && !big[e[i].v]) update(e[i].v, u, val);
}
void dfs(int u, int fa, int keep) {
for(int i=h[u];i;i=e[i].ne)
if(e[i].v != fa && e[i].v != mx[u]) dfs(e[i].v, u, 0);
if(mx[u]) dfs(mx[u], u, 1), big[mx[u]] = 1;
update(u, fa, 1);
ans[u] = f[Max].sec;
big[mx[u]] = 0;
if(!keep) update(u, fa, -1);
}
int main() {
//freopen("in","r",stdin);
n=read();
for(int i=1; i<=n; i++) a[i]=read();
for(int i=1; i<n; i++) ins(read(), read());
dfs(1, 0);
dfs(1, 0, 1);
for(int i=1; i<=n; i++) printf("%I64d ",ans[i]);
}