算法介紹
點分治,顧名思義,是一種對點進行分治的數據結構。(樹上的點)
多用於在樹上進行有限制的路徑計數。
比如:求樹上長度小於$ k$ 的簡單路徑條數。\((n \leq 10000)\)
直接做肯定是補星的。所以就需要點分治這種東西了。
需要統計的路徑肯定有這么兩種:
- 1.經過根節點$ root $的路徑
- 2.不經過根節點\(root\)的路徑
顯然第二種路徑對於某個節點\(u\),屬於第一種路徑。所以分治解決即可。
我們來考慮第一種情況如何解決。
處理出一個數組\(d\),表示從當前根節點\(u\),到各個子節點的距離。
那么我們要統計的顯然就是\(d[u]+d[v]\leq k\)的路徑\((u,v)\)的個數。
這個東西可以通過在dfs求這個數組時順便把所有的\(d\)值記錄下來,排序之后讓他們具有單調性。
然后雙指針掃一下就好(合法狀態就是\(d[l]+d[r]\leq k\))那么當指針在\(l\)時,對答案的貢獻就是\(r-l\)(不能重復選自己,所以不+1)
然后現在考慮一種情況。當\(u,v\)都在當前根節點的同一個子樹里面。這樣子的話,路徑\((u,v)\)如果經過根節點就不是一條簡單路徑了(重邊)。如何解決呢?
容斥的思想!
對於每個子樹,分別處理它其中的子節點的d值,給答案減去就行了!
代碼大概就長這個樣子
void dfs(int u) {
vis[u] = 1;
ans += solve(u, 0); //所有情況
for(int i = head[u]; i; i = e[i].nxt) {
if(vis[e[i].to]) continue;
int v = e[i].to;
ans -= solve(v, e[i].v); //減掉不合法情況
//下面是找重心的代碼,后面會解釋為什么要找重心
now_sz = inf, root = 0; sz = siz[v];
find_root(v, 0);
dfs(root);
}
}
先不管為什么要找重心。我們總結一下算法流程:
- 1.找一個根節點root
- 2.對root計算出d數組並計算答案
- 3.把root刪了,對root的各個子樹執行流程1,2
復雜度是多少呢?粗略估計一下是\(O(Tnlogn) \),\(T\)是樹的層數。(這里有個\(log\)是因為用了排序)
顯然我們要讓這個樹優美一點,身材圓潤一點,不能瘦成一條鏈,不然復雜度就變成\(O(n^2logn) \)了。
那這個根節點怎么找呢?樹的重心!
將重心當做根節點,可以保證樹是\(log\)層的!
那么復雜度就是$O(nlog^2n) \(了!(如果不使用排序的話(比如一些題是用到的桶),那么復雜度是\)O(nlogn)$)
還有就是關於點分治這里的重心有兩種找法。一種就是上面那樣的,另外一種就是改了一句
sz = siz[v];
->sz = siz[v] > siz[u] ? totsiz - siz[u] : siz[v];
實際上第二種才是對的,因為v可能在上次處理siz數組時是u的父親(這是一棵無根樹!)
但是復雜度並不會退化qwq,有神仙證明了。鏈接
例題:
POJ1741 tree
真正的模板題。就是我上面提到的那個問題。
直接點分一下就好了。每次將距離排序一下,然后雙指針掃一掃,每次合法答案就是r-l,容斥一下將不合法情況減去即可。注意找重心不要寫錯不然復雜度就炸了。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
#define N 100010
inline void in(int &x) {
x = 0; int f = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = x * 10 + c - '0';
c = getchar();
}
x *= f;
}
int n, k, d[N], cnt, head[N], ans;
int vis[N], siz[N];
struct edge {
int to, nxt, v;
}e[N<<1];
void ins(int u, int v, int w) {
e[++cnt] = (edge) {v, head[u], w};
head[u] = cnt;
}
int now_sz = inf, root = 0, sz;
void find_root(int u, int fa) {
siz[u] = 1;
int res = 0;
for(int i = head[u]; i; i = e[i].nxt) {
if(vis[e[i].to] || e[i].to == fa) continue;
int v = e[i].to;
find_root(v, u);
siz[u] += siz[v];
res = max(res, siz[v]);
}
res = max(res, sz - siz[u]);
if(res < now_sz) now_sz = res, root = u;
}
int a[N], tot;
void get_dis(int u, int fa) {
a[++tot] = d[u];
for(int i = head[u]; i; i = e[i].nxt) {
if(vis[e[i].to] || e[i].to == fa) continue;
int v = e[i].to;
d[v] = d[u] + e[i].v;
get_dis(v, u);
}
}
int solve(int u, int dis) {
d[u] = dis; tot = 0;
get_dis(u, u);
sort(a + 1, a + tot + 1);
int l = 1, r = tot, res = 0;
for(; l < r; ++l) {
while(l < r && a[l] + a[r] > k) --r;
if(l < r) res += r - l;
}
return res;
}
void dfs(int u) {
vis[u] = 1;
ans += solve(u, 0);
for(int i = head[u]; i; i = e[i].nxt) {
if(vis[e[i].to]) continue;
int v = e[i].to;
ans -= solve(v, e[i].v);
now_sz = inf, root = 0; sz = siz[v];
find_root(v, 0);
dfs(root);
}
}
int main() {
while(~scanf("%d%d", &n, &k) && n && k) {
ans = 0; cnt = 0;
memset(head, 0, sizeof(head));
memset(vis, 0, sizeof(vis));
for(int i = 1; i < n; ++i) {
int u, v, w; in(u), in(v), in(w);
ins(u, v, w), ins(v, u, w);
}
dfs(1);
printf("%d\n", ans);
}
}
BZOJ2152: 聰聰可可
求倍數為3的路徑數。
考慮\(mod\ 3\)意義下的路徑,為0顯然可以互相拼起來,貢獻是\(sum[0]^2\)。1和2可以互相拼,而且起點終點互換,所以貢獻是\(sum[1]*sum[2]*2\),點分治計算這兩個即可。總方案數是\(n^2\),所以答案就是\(\frac{sum}{n^2}\)
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
#define N 100010
inline void in(int &x) {
x = 0; int f = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = x * 10 + c - '0';
c = getchar();
}
x *= f;
}
int n, k, d[N], cnt, head[N], ans;
int vis[N], siz[N], sum[3];
struct edge {
int to, nxt, v;
}e[N<<1];
void ins(int u, int v, int w) {
e[++cnt] = (edge) {v, head[u], w};
head[u] = cnt;
}
int now_siz, sz, root;
void find_root(int u, int fa) {
siz[u] = 1; int res = 0;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa || vis[v]) continue;
find_root(v, u);
siz[u] += siz[v];
res = max(res, siz[v]);
}
res = max(res, sz - siz[u]);
if(res < now_siz) now_siz = res, root = u;
}
void get_dis(int u, int fa) {
sum[d[u]%3]++;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(vis[v] || v == fa) continue;
d[v] = d[u] + e[i].v;
get_dis(v, u);
}
}
int solve(int u, int dis) {
d[u] = dis; sum[0] = sum[1] = sum[2] = 0;
get_dis(u, u);
return sum[0] * sum[0] + sum[1] * sum[2] * 2;
}
void dfs(int u) {
ans += solve(u, 0);
vis[u] = 1;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(vis[v]) continue;
ans -= solve(v, e[i].v);
now_siz = inf; sz = siz[v]; root = 0;
find_root(v, u);
dfs(root);
}
}
int main() {
in(n);
for(int i = 1; i < n; ++i) {
int u, v, w; in(u), in(v), in(w);
ins(u, v, w), ins(v, u, w);
}
now_siz = inf; root = 0; sz = n;
find_root(1, 1);
dfs(root);
int now = n * n, g = __gcd(now, ans);
printf("%d/%d\n", ans / g, now / g);
}
LuoguP3806 【模板】點分治1
注意這題數據很水...
求長度為k的路徑是否存在。多次詢問(詢問數\(\leq 100\))
這題效率有點奇怪...
自己估算了一下是\(O(mnlog^2n)\)。
對長度正好k的話,其實用個桶標記就好了,實際上和小於k沒多大區別的。
考慮先將詢問離線,然后在點分治過程中對所有答案進行判定。處理出d[]表示到節點i到當前根的距離。那么照例是拼路徑,但是現在不是求方案總數而是求有沒有這個方案,看起來不能容斥了。但是實際上可以的:考慮先對根u solve一遍,給所有詢問加上這次的結果,然后對每個子節點計算一遍,給所有詢問減掉這次的結果就好了。
具體的話看看代碼吧
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
#define N 100010
#define lim 10000000
inline void in(int &x) {
x = 0; int f = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = x * 10 + c - '0';
c = getchar();
}
x *= f;
}
int top, n, m, d[N], cnt, head[N], ans[110];
int vis[N], siz[N], q[110], st[N], s[10000010];
struct edge {
int to, nxt, v;
}e[N<<1];
void ins(int u, int v, int w) {
e[++cnt] = (edge) {v, head[u], w};
head[u] = cnt;
}
int now_sz = inf, root, sz;
void find_root(int u, int fa) {
siz[u] = 1; int res = 0;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa || vis[v]) continue;
find_root(v, u);
res = max(res, siz[v]);
siz[u] += siz[v];
}
res = max(res, sz - siz[u]);
if(res < now_sz) now_sz = res, root = u;
}
void get_dis(int u, int fa) {
st[++top] = d[u];
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa || vis[v]) continue;
d[v] = d[u] + e[i].v;
get_dis(v, u);
}
}
void solve(int u, int dis, int op) {
top = 0; d[u] = dis; get_dis(u, 0);
for(int i = 1; i <= top; ++i) if(st[i] <= lim) s[st[i]]++;
for(int i = 1; i <= m; ++i) {
for(int j = 1; j <= top; ++j) if(q[i] >= st[j]) ans[i] += s[q[i] - st[j]] * op;
}
for(int i = 1; i <= top; ++i) if(st[i] <= lim) s[st[i]]--;
}
void dfs(int u) {
vis[u] = 1;
solve(u, 0, 1);
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(vis[v]) continue;
top = 0; d[v] = e[i].v;
solve(v, e[i].v, -1);
now_sz = inf, root = 0, sz = siz[v];
find_root(v, u);
dfs(root);
}
}
int main() {
in(n), in(m);
for(int i = 1; i < n; ++i) {
int u, v, w; in(u), in(v), in(w);
ins(u, v, w), ins(v, u, w);
}
for(int i = 1; i <= m; ++i) in(q[i]);
sz = n; now_sz = inf; root = 0;
find_root(1, 1); dfs(root);
for(int i = 1; i <= m; ++i) puts(ans[i] ? "AYE" : "NAY");
}
CF161D Distance in Tree
求長度等於k的路徑數...就很煩....這種一般都要分類討論
需要分類討論一下,同樣是套路點分然后開個桶,然后分\(k-v[i]=v[i]\)和不等兩種情況,顯然相等的話答案就是\(cnt[v[i]]*(cnt[v[i]]-1)/2\).不相等的話用乘法原理考慮一下,\(cnt[v[i]]*cnt[k-v[i]]\),注意每次統計完之后就要把cnt清空。
#include <bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
#define il inline
namespace io {
#define in(a) a = read()
#define out(a) write(a)
#define outn(a) out(a), putchar('\n')
#define I_int ll
inline I_int read() {
I_int x = 0, f = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = x * 10 + c - '0';
c = getchar();
}
return x * f;
}
char F[200];
inline void write(I_int x) {
if (x == 0) return (void) (putchar('0'));
I_int tmp = x > 0 ? x : -x;
if (x < 0) putchar('-');
int cnt = 0;
while (tmp > 0) {
F[cnt++] = tmp % 10 + '0';
tmp /= 10;
}
while (cnt > 0) putchar(F[--cnt]);
}
#undef I_int
}
using namespace io;
using namespace std;
#define N 100010
int n, k;
int cnt, head[N], vis[N], d[N];
struct edge {
int to, nxt;
}e[N<<1];
void ins(int u, int v) {
e[++cnt] = (edge) {v, head[u]};
head[u] = cnt;
}
int siz[N], now_sz = inf, root, sz;
void find_root(int u, int fa) {
siz[u] = 1; int res = 0;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa || vis[v]) continue;
find_root(v, u);
siz[u] += siz[v];
res = max(res, siz[v]);
}
res = max(res, sz - siz[u]);
if(res < now_sz) now_sz = res, root = u;
}
int top, st[N], s[N];
void get_dis(int u, int fa) {
st[++top] = d[u]; if(d[u] <= k) ++s[d[u]];
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa || vis[v]) continue;
d[v] = d[u] + 1;
get_dis(v, u);
}
}
ll solve(int u, int dis) {
d[u] = dis; top = 0; get_dis(u, 0);
ll ans = 0;
for(int i = 1; i <= top; ++i)
if(st[i] <= k) {
if(st[i] * 2 == k) ans += 1ll * s[st[i]] * (s[st[i]] - 1) / 2ll;
else ans += 1ll * s[k - st[i]] * s[st[i]];
s[st[i]] = s[k - st[i]] = 0;
}
return ans;
}
ll ans = 0;
void dfs(int u) {
vis[u] = 1; ans += solve(u, 0);
int totsiz = sz;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(vis[v]) continue;
ans -= solve(v, 1);
sz = siz[v] > siz[u] ? totsiz - siz[u] : siz[v];
now_sz = inf; root = 0;
find_root(v, 0);
dfs(root);
}
}
int main() {
in(n), in(k);
for(int i = 1; i < n; ++i) {
int u = read(), v = read();
ins(u, v), ins(v, u);
}
now_sz = inf; sz = n; root = inf;
find_root(1, 0);
dfs(root);
outn(ans);
}