[2020-CCPC Changchun Onsite]-F. Strange Memory(dsu on tree)
題面:
題意:
給定一個含有\(\mathit n\)個節點的數,求下式的值。
\[\sum\limits_{i=1}^n\sum\limits_{j=i+1}^n [a_i \oplus a_j = a_{\operatorname{lca}(i, j)}] (i \oplus j). \]
思路:
觀察數據:\(1 \leq a_i \leq 10^6\),那么從根節點到葉子節點構成的路徑以及它們的子路徑不會對答案產生貢獻。
因為在同一個鏈上時,設\(dep_u>dep_v\),那么\(lca(u,v)=v\),因為\(a_u>0\),所以\(a_u\oplus a_v\not=a_v\)。
則可以推出答案值來源於一個節點作為根的子樹中根節點的不同兒子子樹之間的貢獻。
那么問題可以轉化為求出每一個子樹的lca為子樹根的子節點們對答案的貢獻總和。
子樹問題考慮到使用dsu on tree算法,進行輕重鏈剖分。
因為答案是點對的異或值總和,考慮到將其二進制拆分,按位選貢獻,開桶維護個數。
時間復雜度:\(O(n*log^2(n))\)
代碼:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e5 + 10;
int cnt[maxn * 15][18][2];
vector<int>v[maxn];
int n;
int a[maxn];
ll sum = 0ll;
int son[maxn], siz[maxn];
int Son;
ll base[30];
void dfs(int x, int fa)
{
siz[x] = 1;
for (int i = 0; i < v[x].size(); ++i) {
int to = v[x][i];
if (to == fa) {
continue;
}
dfs(to, x);
siz[x] += siz[to];
if (siz[to] > siz[son[x]]) {
son[x] = to;
}
}
}
void add(int x, int fa, int val, int num)
{
if (val == 0)
for (int i = 0; i < 18; ++i) {
sum += base[i] * cnt[a[x] ^ num][i][!((x >> i) & 1)];
}
for (int i = 0; i < v[x].size(); ++i) {
int to = v[x][i];
if (to == fa || to == Son) {
continue;
}
add(to, x, val, num);
}
if (val != 0)
for (int i = 0; i < 18; ++i) {
cnt[a[x]][i][(x >> i) & 1] += val;
}
}
void dfs2(int x, int fa, int opt)
{
for (int i = 0; i < v[x].size(); ++i) {
int to = v[x][i];
if (to == fa) {
continue;
}
if (to != son[x]) {
dfs2(to, x, 0);
}
}
if (son[x]) {
dfs2(son[x], x, 1);
Son = son[x];
}
for (int i = 0; i < v[x].size(); ++i) {
int to = v[x][i];
if (to == fa || to == Son) {
continue;
}
add(to, x, 0, a[x]);
add(to, x, 1, a[x]);
}
for (int i = 0; i < 18; ++i) {
cnt[a[x]][i][(x >> i) & 1]++;
}
Son = 0;
if (!opt) {
for (int i = 0; i < 18; ++i) {
cnt[a[x]][i][(x >> i) & 1]--;
}
for (int i = 0; i < v[x].size(); ++i) {
int to = v[x][i];
if (to == fa || to == Son) {
continue;
}
add(to, x, -1, a[x]);
}
}
}
int main()
{
base[0] = 1ll;
for (int i = 1; i <= 20; ++i) {
base[i] = base[i - 1] * 2ll;
}
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%d", &a[i]);
}
for (int i = 1; i <= n - 1; ++i) {
int x, y;
scanf("%d %d", &x, &y);
v[x].push_back(y);
v[y].push_back(x);
}
dfs(1, 0);
dfs2(1, 0, 0);
printf("%lld\n", sum );
return 0;
}