點分治詳解
一.概念
是處理樹上路徑的一個極好的方法。如果你需要大規模的處理一些樹上路徑的問題時,點分治是一個不錯的選擇。
二.具體思路
大多數同學的暴力做法都是對於每一個點對(u,v) 進行dfs來求解。但其實利用分治這一種算法,可以大大減少搜索的時間復雜度。
對於一個序列上的區間和等操作,我們可以使用分治來將原問題分解成幾個子問題來求解,之后在一一合並答案。而在樹上我們也是可以進行這一種操作的。可是樹上的每一個子樹的節點數是不確定的,不能單單的取中點(你告訴我怎么取),或直接取一號子樹。(分治的點的錯誤選擇會導致時間復雜度十分不穩定)。
如下圖所示,如果你取了第一個點的話,那么時間復雜度會變\(O(n)\),但如果我們取的點是3的話,那么時間復雜度就會是 \(O(logn)\)
所以,我們要引入一個概念 —— 樹的重心
定義:找到一個點,其所有的子樹中最大的子樹節點數最少,那么這個點就是這棵樹的重心,刪去重心后,生成的多棵樹盡可能平衡
由定義可知,當我們選擇樹的重心為分支點時,是最優的(我有個絕妙的證明只是這里寫不下)
好了,求出了樹的重心之后我們就可以來分治了!!
先現給出求重心的代碼,便於讀者依次理解
void find(int x,int fa)
{
size[x] = 1; mx[x] = 0;
for (int i = head[x]; i ; i = edges[i].net)
{
edge v = edges[i];
if(v.to == fa||vis[v.to] ) continue;//vis是之后分治是要用到的
find(v.to,x);
size[x] += size[v.to];
chkmax(mx[x],size[v.to]);
}
chkmax(mx[x],S-size[x]);//S為樹的大小,記住x的上面要算入的
if(mx[x] < mx[root])
{
root = x;
}
}
現在開始我們點分治中最重要的部分了 —— 分治
分治不太好講,我們從代碼開始分析
void Divid(int x)
{
ans+=solve(x,0);
vis[x] = 1;
for (int i = head[x];i;i = edges[i].net)
{
edge v = edges[i];
if(vis[v.to]) continue;
ans-=solve(v.to,edges[i].cost);
S = size[v.to]; root = 0;
find(v.to,x);
Divid(root);
}
}
- ans += solve(x,0); 這一句的作用是將答案加上經過x的路徑答案。 而這一個0是為了解決掉一些,有重復計算的結果;(看不懂先假裝沒有這個0)
- ans -= solve(v.to,edges[i].cost); 這一句是將在既經過x這個點,又經過v.to這一個點的路徑來去重。因為像這種路徑會在solve(x,0)和solve(v.to,0)中都計算一次。而題目是要求路徑的長度,所以在容斥時要初始化這條邊的長度。所以,現在有沒有理解這個0和edges[i].cost?
- S = size[v.to]; 現在我們要分治v.to的這一顆子樹,So,又將求重心的樹的大小改為size[v.to];
到此為止,點分治就在這里講完了,solve函數是看題目的,有能力的同學可以切一切這兩道題(這兩道題會在下面進行講解)。luogu模板題 和聰聰可可.
三.例題分析
1.luogu模板題
題面在上面。
因為題目是要求路徑長為k的路徑條數,所以solve函數返回的是過x節點的長度為k的路徑。
而這路徑長度是可以用 \(O(n)\) 的方法求出
// luogu-judger-enable-o2
#include<bits/stdc++.h>
template <class T>
inline void read(T &a)
{
T s = 0, w = 1;
char c = getchar();
while(c < '0' || c > '9')
{
if(c == '-') w = -1;
c = getchar();
}
while(c >= '0' && c <= '9')
{
s = (s << 1) + (s << 3) + (c ^ 48);
c = getchar();
}
a = s*w;
}
template<class T> void chkmax(T &a, T b) {a > b ? (a = a) : (a = b);}
template<class T> void chkmin(T &a, T b) {a > b ? (a = b) : (a = a);}
template<class T> T min(T a, T b) {return a > b ? b : a;}
template<class T> T max(T a, T b) {return a < b ? b : a;}
int n,m;
int S;
int size[10101];
struct edge{
int from,to,cost,net;
edge(int f = 0, int t = 0, int cost = 0, int nex = 0)
{
from = f;
to = t;
this->cost = cost;
net = nex;
}
}edges[1010101];
int tot,head[101001],mx[101011],minn =0x3f3f3f3f,root;
int vis[1010110];
void add(int x, int y, int z)
{
edges[++tot] = edge(x,y,z,head[x]);
head[x] = tot;
}
void find(int x,int fa)
{
size[x] = 1;mx[x] = 0;
for (int i = head[x];i; i =edges[i].net)
{
edge v = edges[i];
if(v.to == fa || vis[v.to]) continue;
find(v.to,x);
size[x] += size[v.to];
chkmax(mx[x],size[v.to]);
}
chkmax(mx[x], S - size[x]);
if(mx[x] < mx[root])
{
root = x;
}
}
int que[1010110],ans[102210101];
int dis[1010101],hhd,a[10101101];
void get_dis(int x, int len, int fa)
{
dis[++hhd] = a[x];
for (int i = head[x]; i; i = edges[i].net)
{
edge v = edges[i];
if(vis[v.to]||v.to == fa) continue;
a[v.to] = len + edges[i].cost;
get_dis(v.to,len + edges[i].cost,x);
}
}
void solve(int s, int len, int w)
{
hhd = 0;
a[s] = len;
get_dis(s,len,0);
for (int i1 = 1; i1 <= hhd; i1++)
for (int i2 = 1; i2 <= hhd; i2++)
{
if(i1 != i2)
{
ans[dis[i1] + dis[i2]] += w;
}
}
}
void Divide(int x)
{
solve(x,0,1);
vis[x] = 1;
for (int i = head[x]; i; i = edges[i].net)
{
edge v = edges[i];
if(vis[v.to]) continue;
solve(v.to,edges[i].cost,-1);
S = size[x];root = 0; mx[0] = n;
find(v.to,x);
Divide(root);
}
}
int main()
{
read(n); read(m);
for (int i = 1; i < n; i++)
{
int x,y,z;
read(x); read(y); read(z);
add(x,y,z);
add(y,x,z);
}
S = n;mx[0] = n;root = 0;
// minn = 0x3f3f3f3f;
find(1,0);
// printf("%d\n",mx[root]);
Devede(root);
for (int i = 1; i <= m; i++)
{
int k;
read(k);
printf("%s\n",(ans[k]) ? "AYE" : "NAY");
//printf("%d\n",ans[k]);
}
return 0;
}
2.聰聰可可
這道題是來求長度被3整除的路徑條數,但處理方法跟上一條不太一樣。
我們可以設p[0],p[1],p[2]為除3余數為0,1,2的 路徑條數。顯然答案為\(p_0^2\) + \(p_1 * p_2 * 2\)
// luogu-judger-enable-o2
// luogu-judger-enable-o2
// luogu-judger-enable-o2
#include<bits/stdc++.h>
int gcd(int x, int y)
{
if(y == 0) return x;
return gcd(y,x%y);
}
template<class T>
inline void read(T &a)
{
T s = 0,w = 1;
char c = getchar();
while(c < '0' || c > '9')
{
if(c == '-') w = -1;
c = getchar();
}
while(c >= '0' && c <= '9')
{
s = (s << 1) + (s << 3) + (c ^ 48);
c = getchar();
}
a = s*w;
}
template<class T> void chkmax(T &a, T b){a > b? (a = a) : (a = b);}
template<class T> void chkmin(T &a, T b){a > b ? (a = b):(a = a);}
int n;
struct edge{
int from, to,cost,net;
edge(int f = 0, int t = 0, int c = 0, int n = 0)
{
from = f;
to = t;
cost = c;
net = n;
}
}edges[2010101];
static int head[20010],tot;
void add(int x, int y, int z)
{
edges[++tot] = edge(x,y,z,head[x]);
head[x] = tot;
}
static int vis[20010],size[20010],mx[20010],root,S;
void find(int x,int fa)
{
size[x] = 1; mx[x] = 0;
for (int i = head[x]; i ; i = edges[i].net)
{
edge v = edges[i];
if(v.to == fa||vis[v.to] ) continue;
find(v.to,x);
size[x] += size[v.to];
chkmax(mx[x],size[v.to]);
}
chkmax(mx[x],S-size[x]);
if(mx[x] < mx[root])
{
root = x;
}
}
int dis[20010],a[20010],cnt;
int ans,p[3];
void get_dis(int x, int fa)
{
// dis[++cnt] = a[x];
p[a[x]%3]++;
for (int i = head[x] ;i; i = edges[i].net)
{
edge v = edges[i];
if(v.to == fa ||vis[v.to] ) continue;
a[v.to] = a[x]+v.cost;
get_dis(v.to,x);
}
}
int solve(int x, int len)
{
a[x] = len;
//cnt = 0;
p[0] = p[1] = p[2] = 0;
get_dis(x,0);
return (p[0]*p[0] + 2 * p[1] * p[2]);
}
void Deved(int x)
{
ans+=solve(x,0);
vis[x] = 1;
for (int i = head[x];i;i = edges[i].net)
{
edge v = edges[i];
if(vis[v.to]) continue;
ans-=solve(v.to,edges[i].cost);
S = size[v.to]; root = 0;
find(v.to,x);
Deved(root);
}
}
int main()
{
//freopen("xx.in","r",stdin);
//freopen("xx.out","w",stdout);
read(n);
for (register int i = 1; i < n; i++)
{
int x,y,z;
read(x); read(y); read(z);
z%=3;
add(x,y,z);
add(y,x,z);
}
S = n;root = 0; mx[0] = n+1;
find(1,0);
Deved(root);
int pp = gcd(ans,n*n);
printf("%lld/%lld\n",ans/pp,n*n/pp);
// std::cerr<<std::clock()<<std::endl;
return 0;
}