這種不怎么難寫的東西,我學得快忘得也快,也是給自己加深印象,同時留個自己(大概)能看懂的講解好復習……qwq
先說是什么
dsu on tree中的dsu就是Disjoint Set Union,雖然整個算法跟並茶幾(話說並茶幾名字好多啊……)沒有任何關系……硬要說就是借用了啟發式合並的思想吧……
這個算法是拿來解決樹上對子樹內答案的詢問的,當然它並不支持修改
它在暴力的基礎上,借助輕重鏈剖分的性質把復雜度降低到了\(O(n \log n)\)
大致過程
遍歷每一個節點,先遞歸解決輕兒子,完成后消除遞歸產生的影響
然后解決重兒子,但不消除影響
將輕兒子的答案合並上來
消除上一個過程中輕兒子產生的影響
拿一個例題來說:CF600E
題意:樹上每個節點有一個顏色,求每棵子樹中出現次數最多的顏色(可能有多個)之和
首先是輕重鏈剖分,處理出每個節點的重兒子
void dfs(int u, int fa) {
size[u] = 1;
for (int i = G.head[u]; ~i; i = G[i].next) {
int v = G[i].v;
if (v == fa) continue;
dfs(v, u);
size[u] += size[v];
if (!heavy[u] || size[v] > size[heavy[u]]) heavy[u] = v;
}
}
然后遍歷節點,按上面的流程來(具體看注釋)
void update(int u, int fa, int val, const int &hvy) {//val為1暴力合並統計輕兒子的答案,為-1清除對cnt的影響
cnt[col[u]] += val;
if (val > 0 && cnt[col[u]] >= max_cnt) {
if (cnt[col[u]] > max_cnt) sum = 0, max_cnt = cnt[col[u]];
sum += (LL)col[u];
}
for (int i = G.head[u]; ~i; i = G[i].next) {
int v = G[i].v;
if (v == fa || v == hvy) continue;
update(v, u, val, hvy);
}
}
void dfs(int u, int fa, int opt) {//opt為0表示需要清除掉u的影響,為1表示不需要
for (int i = G.head[u]; ~i; i = G[i].next) {
int v = G[i].v;
if (v == fa || v == heavy[u]) continue;
dfs(v, u, 0);//遞歸解決輕兒子,完成后清除影響
}
if (heavy[u]) dfs(heavy[u], u, 1);//解決重兒子,保留影響
update(u, fa, 1, heavy[u]);//合並輕兒子的答案
ans[u] = sum;
if (!opt) update(u, fa, -1, 0), sum = 0, max_cnt = 0;//清除影響
}
最后是完整代碼:
PS.怎么網上的博客代碼個個不一樣啊……,蒟蒻我懵逼了好長時間才看明白qwq
#include <cstdio>
#include <cstring>
#include <iostream>
#define MAXN 100005
typedef long long LL;
struct Graph {
struct Edge {
int v, next;
Edge(int _v = 0, int _n = 0):v(_v), next(_n) {}
} edge[MAXN << 1];
int head[MAXN], cnt;
void init() { memset(head, -1, sizeof head); cnt = 0; }
void add_edge(int u, int v) { edge[cnt] = Edge(v, head[u]); head[u] = cnt++; }
void insert(int u, int v) { add_edge(u, v); add_edge(v, u); }
Edge & operator [](int x) { return edge[x]; }
} G;
int col[MAXN], size[MAXN], heavy[MAXN], val[MAXN], cnt[MAXN], N;
LL sum, max_cnt, ans[MAXN];
void dfs(int, int);
void dfs(int, int, int);
void update(int, int, int, const int &);
int main() {
G.init();
scanf("%d", &N);
for (int i = 1; i <= N; ++i) scanf("%d", col + i);
for (int i = 1; i < N; ++i) {
int u, v;
scanf("%d%d", &u, &v);
G.insert(u, v);
}
dfs(1, 0);
dfs(1, 0, 0);
for (int i = 1; i <= N; ++i) printf("%I64d ", ans[i]);
return 0;
}
void dfs(int u, int fa) {
size[u] = 1;
for (int i = G.head[u]; ~i; i = G[i].next) {
int v = G[i].v;
if (v == fa) continue;
dfs(v, u);
size[u] += size[v];
if (!heavy[u] || size[v] > size[heavy[u]]) heavy[u] = v;
}
}
void update(int u, int fa, int val, const int &hvy) {
cnt[col[u]] += val;
if (val > 0 && cnt[col[u]] >= max_cnt) {
if (cnt[col[u]] > max_cnt) sum = 0, max_cnt = cnt[col[u]];
sum += (LL)col[u];
}
for (int i = G.head[u]; ~i; i = G[i].next) {
int v = G[i].v;
if (v == fa || v == hvy) continue;
update(v, u, val, hvy);
}
}
void dfs(int u, int fa, int opt) {
for (int i = G.head[u]; ~i; i = G[i].next) {
int v = G[i].v;
if (v == fa || v == heavy[u]) continue;
dfs(v, u, 0);
}
if (heavy[u]) dfs(heavy[u], u, 1);
update(u, fa, 1, heavy[u]);
ans[u] = sum;
if (!opt) update(u, fa, -1, 0), sum = 0, max_cnt = 0;
}
//Rhein_E
最后是復雜度證明
輕重鏈剖分保證了每個節點到根的路徑上輕邊條數不超過\(\log n\)
每個節點被訪問一次,要么是它的祖先節點暴力統計輕兒子/消除影響的時候,要么是它自己統計答案的時候
前者\(O(\log n)\)次,后者\(1\)次
所以總復雜度是\(O(n \log n)\)的