題目描述
Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
輸入
The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.
The last test case is followed by two zeros.
輸出
For each test case output the answer on a single line.
樣例輸入
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
樣例輸出
8
題目大意
多組測試數據,每次輸入n、m,和一棵n個點的有邊權的樹,問你滿足x到y距離小於等於m的無序點對(x,y)的個數是多少。
題解
樹的點分治模板題,第一次寫
考慮到路徑只有兩種情況,一是經過根節點,二是不經過根節點。
如果不經過根節點,那么一定經過最小公共子樹的根節點,可以轉化為問題一的子問題。
於是考慮怎么遞歸解決問題一。
對於根節點進行一次dfs,求出deep,並將其從小到大排序。
避免重復,只需要求出其中deep[x]≤deep[y]且deep[x]+deep[y]≤m的個數。
用i表示左指針,j表示右指針,i從左向右遍歷。
如果deep[i]+deep[j]≤m,則點對(i,t)(i<t≤j)都符合題意,將j-i加入答案中,並且i++;否則j--。
然而這樣還會重復計算在同一棵子樹中的點對,所以再進行下一步dfs之前需要減去重復部分。
但是這樣做會TLE。為什么?因為樹可能會退化,導致選擇鏈頭時時間復雜度極大。
於是每次不能固定選擇root,而是以重心作為root去處理,這樣能保證時間復雜度再O(nlog2n)以下。
#include <cstdio> #include <cstring> #include <algorithm> #define N 10010 using namespace std; int m , head[N] , to[N << 1] , len[N << 1] , next[N << 1] , cnt , si[N] , deep[N] , root , vis[N] , f[N] , sn , d[N] , tot , ans; void add(int x , int y , int z) { to[++cnt] = y , len[cnt] = z , next[cnt] = head[x] , head[x] = cnt; } void getroot(int x , int fa) { f[x] = 0 , si[x] = 1; int i; for(i = head[x] ; i ; i = next[i]) if(to[i] != fa && !vis[to[i]]) getroot(to[i] , x) , si[x] += si[to[i]] , f[x] = max(f[x] , si[to[i]]); f[x] = max(f[x] , sn - si[x]); if(f[root] > f[x]) root = x; } void getdeep(int x , int fa) { d[++tot] = deep[x]; int i; for(i = head[x] ; i ; i = next[i]) if(to[i] != fa && !vis[to[i]]) deep[to[i]] = deep[x] + len[i] , getdeep(to[i] , x); } int calc(int x) { tot = 0 , getdeep(x , 0) , sort(d + 1 , d + tot + 1); int i = 1 , j = tot , sum = 0; while(i < j) { if(d[i] + d[j] <= m) sum += j - i , i ++ ; else j -- ; } return sum; } void dfs(int x) { deep[x] = 0 , vis[x] = 1 , ans += calc(x); int i; for(i = head[x] ; i ; i = next[i]) if(!vis[to[i]]) deep[to[i]] = len[i] , ans -= calc(to[i]) , sn = si[to[i]] , root = 0 , getroot(to[i] , 0) , dfs(root); } int main() { int n , i , x , y , z; while(scanf("%d%d" , &n , &m) && (n || m)) { memset(head , 0 , sizeof(head)); memset(vis , 0 , sizeof(vis)); cnt = 0 , ans = 0; for(i = 1 ; i < n ; i ++ ) scanf("%d%d%d" , &x , &y , &z) , add(x , y , z) , add(y , x , z); f[0] = 0x7fffffff , sn = n; root = 0 , getroot(1 , 0) , dfs(root); printf("%d\n" , ans); } return 0; }