【AtCoder3611】Tree MST(點分治,最小生成樹)
題面
AtCoder
洛谷
給定一棵\(n\)個節點的樹,現有有一張完全圖,兩點\(x,y\)之間的邊長為\(w[x]+w[y]+dis(x,y)\),其中\(dis\)表示樹上兩點的距離。
求完全圖的\(MST\)。
題解
首先連邊的這個式子可以直接轉換成樹上的兩點間的路徑,所以接下來只考慮\(dis(x,y)\)。
考慮\(Boruvka\)算法的執行過程,每次都會選擇到達一個點集最近的一個點,然后將他們連邊。
現在考慮模擬這個過程,那么在樹上我們欽定一個點作為根節點,考慮過根節點的路徑的連邊情況。
對於每個點我們要找到離他最近的點,因此顯然就是在其他子樹內找到一個距離根節點最小的點然后讓這個點向所有其他的點連邊,這樣子對於每個根節點我們都可以找到唯一的點,然后讓它向其他所有點連邊就好了。
顯然不用以所有點為根,只需要把當前根節點丟掉,把子樹再單獨處理就好了。
不難發現點分治就是這么一個過程,因此對於每個分治重心找到距離根節點距離最近的點。
這樣子一來點分治是\(O(nlogn)\),邊數是\(O(nlogn)\),最后再跑一遍克魯斯卡爾。
因此總的復雜度就是\(O(nlog^2n)\)。
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define MAX 200200
#define ll long long
inline int read()
{
int x=0;bool t=false;char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=true,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return t?-x:x;
}
struct Line{int v,next,w;}e[MAX<<1];
int h[MAX],cnt=1;
inline void Add(int u,int v,int w){e[cnt]=(Line){v,h[u],w};h[u]=cnt++;}
struct Edge{int u,v;ll w;}E[MAX*50];
bool operator<(Edge a,Edge b){return a.w<b.w;}
int Size,mx,rt,size[MAX];bool vis[MAX];
int n,m,W[MAX];
void Getroot(int u,int ff)
{
int ret=0;size[u]=1;
for(int i=h[u];i;i=e[i].next)
{
int v=e[i].v;if(vis[v]||v==ff)continue;
Getroot(v,u);size[u]+=size[v];
ret=max(ret,size[v]);
}
ret=max(ret,Size-size[u]);
if(ret<mx)mx=ret,rt=u;
}
int P;ll Val;
void dfs(int u,int ff,ll dep)
{
if(dep+W[u]<Val)Val=dep+W[u],P=u;
for(int i=h[u];i;i=e[i].next)
if(!vis[e[i].v]&&e[i].v!=ff)
dfs(e[i].v,u,dep+e[i].w);
}
void Link(int u,int ff,ll dep)
{
E[++m]=(Edge){u,P,Val+dep+W[u]};
for(int i=h[u];i;i=e[i].next)
if(!vis[e[i].v]&&e[i].v!=ff)
Link(e[i].v,u,dep+e[i].w);
}
void Divide(int u)
{
vis[u]=true;
Val=1e18;P=0;dfs(u,0,0);Link(u,0,0);
for(int i=h[u];i;i=e[i].next)
{
int v=e[i].v;if(vis[v])continue;
Size=mx=size[v];Getroot(v,u);
Divide(rt);
}
}
int f[MAX];ll ans;
int getf(int x){return x==f[x]?x:f[x]=getf(f[x]);}
int main()
{
n=read();
for(int i=1;i<=n;++i)W[i]=read();
for(int i=1;i<n;++i)
{
int u=read(),v=read(),w=read();
Add(u,v,w);Add(v,u,w);
}
Size=mx=n;Getroot(1,0);Divide(rt);
sort(&E[1],&E[m+1]);
for(int i=1;i<=n;++i)f[i]=i;
for(int i=1;i<=m;++i)
{
int u=getf(E[i].u),v=getf(E[i].v);
if(u==v)continue;
ans+=E[i].w;f[u]=v;
}
printf("%lld\n",ans);
return 0;
}