// 此博文為遷移而來,寫於2015年7月11日,不代表本人現在的觀點與看法。原始地址:http://blog.sina.com.cn/s/blog_6022c4720102w69l.html
UPDATE(20180824):進行多處修正,並添加多處注釋,代碼重寫。感謝評論區的建議。
一、前言
樹鏈剖分,一個高大上的名字。樹鏈,即樹上的路徑,現在我們的任務是所謂的剖分。所以我們可以看出,樹鏈剖分並不是一種單獨的數據結構,不像堆,線段樹等等,而是直接在一棵普通的樹上處理,然而單是這一課樹是並沒有什么卵用的。今天先講一個相對比較簡單的情況——用一棵線段樹維護主樹每條邊的權值。
二、概念
首先引入幾個概念。
> 重兒子:設非葉子節點u存在若干個子節點,每個子節點有若干個子節點,重兒子即為其子節點中子節點最多的節點。
> 重邊:非葉子節點與其重兒子所連的邊。
> 重鏈:由連續的重邊組成的一條鏈。
那么這些東西有什么用?先來看一道例題(也就是樹鏈剖分與線段樹維護的經典例題)。
三、例題
四、過程
題目為單點修改,區間詢問。單點修改不用多提,重點在於,我們如何把從節點u到節點v這條路徑上的節點求出最大值以及權值和呢?先考慮一種暴力的算法——跑LCA。我們根據u和v的深度找到公共父親節點,然后在從子節點向上跳的時候,得到最大值或是權值和(如果是修改操作,其實也是同理)。然而這終究是暴力。那么,重鏈在這道題中的作用就凸顯出來了——為了你在跑LCA的時候往上跳得更快。
根據最開始對重鏈等概念的描述,我們來看一張圖:
第一步,先求出每一個節點的重兒子,以及當前節點所在重鏈的頂端(如果當前節點是沒有重邊相連或者本身就是頂端,則就是其本身)。
第二步,根據重兒子,我們將每個節點對應的邊標號(由於這是一棵樹,則每一個節點與其父節點之間的邊有且僅有一條,我們稱之為節點對應的邊)。編號時,優先為其重兒子編號,直到到了葉子節點,再回溯上去為其他的兒子編號。如圖所示,最開始我們記根節點的重邊為1,然后一路編下去直至14號節點,回溯到4號節點,將4號節點與另一個子節點的邊標號為5,以此類推。
這樣,我們的目的就開始顯現了!一旦重鏈存在兩條及以上的重邊,其編號在線段樹中一定是連續的。如圖的1-2-3-4與10-11。
第三步,跑LCA。由於已經求出了每條重鏈的頂端,每次我們跑LCA的時候,若當前節點不是重鏈頂端,則可以直接跳到頂端——同時,因為他們在線段樹的編號是連續的,所以可以很方便的進行求值或者是修改,這一點只要會線段樹的就很好理解了!
舉個例子,如果我們需要求出11號和10號點的路徑上的權值之和,設初始狀態x1=11,x2=10,步驟為:
1、11的頂端為2,修改線段樹中的10-11,同時x1=2;這時,dep[x1]=2,dep[x2]=3;
2、10沒有重邊相連,頂端為自身,故向上找其父節點,修改線段樹中的4;發現父節點所在重鏈頂端為1,則修改線段樹中的1,同時x2=4;這時,dep[x1]=2,dep[x2]=1。(這里有一個小小的優化,即便這條鏈上是一條重邊,一條輕邊,也可以選擇一次性向上跳完)
3、2沒有重邊相連,頂端為自身,故向上找其父節點,修改線段樹中的9。此時,top[x1]=top[x2],且x1=x2,循環結束。
五、代碼
盡管個人認為整個過程已經描述的較為清楚,但代碼實現起來依舊有很多需要注意的細節,原因在於樹鏈剖分涉及面廣,先進行的兩遍DFS再加上后面的線段樹操作,代碼長,容易碼錯。這里對代碼進行一些提示:
1、geth函數為第一次DFS,作用在於求出每一個節點的重兒子及其在樹中的深度與其父節點;
2、mark函數為對每一條邊進行標號,優先重邊,同時維護好每一個節點與其對應邊的關系;
3、qmax/qsum:本質為跑LCA,在深度不等的情況下,每次對深度較大的點向上找祖先,如果找到重鏈,則直接利用線段樹維護的數據加快速度。
1 #include <cstdio> 2 #include <algorithm> 3 using namespace std; 4 5 #define MAXN 30005 6 #define INF 0x3f3f3f3f 7 8 int n, q, u, v, o, w[MAXN], h[MAXN]; 9 int f[MAXN], d[MAXN], tot[MAXN], hs[MAXN], top[MAXN], num[MAXN], lik[MAXN], now; 10 char ch[12]; 11 12 struct Tree { 13 int m, s; 14 } t[MAXN << 2]; 15 16 struct Edge { 17 int v, next; 18 } e[MAXN << 1]; 19 20 void add(int u, int v) { 21 o++, e[o] = (Edge) {v, h[u]}, h[u] = o; 22 o++, e[o] = (Edge) {u, h[v]}, h[v] = o; 23 } 24 25 int geth(int o, int of, int od) { 26 int oh = -1; 27 f[o] = of, d[o] = od; 28 for (int x = h[o]; x; x = e[x].next) { 29 int v = e[x].v; 30 if (v == of) continue; 31 tot[o] += geth(v, o, od + 1); 32 if (tot[v] > oh) oh = tot[v], hs[o] = v; 33 } 34 return tot[o] + 1; 35 } 36 37 void mark(int o, int ot) { 38 now++, top[o] = ot, num[o] = now, lik[now] = o; 39 if (!hs[o]) return; 40 mark(hs[o], ot); 41 for (int x = h[o]; x; x = e[x].next) { 42 int v = e[x].v; 43 if (v != hs[o] && v != f[o]) mark(v, v); 44 } 45 } 46 47 void build(int o, int l, int r) { 48 if (l == r) { 49 t[o] = (Tree) {w[lik[l]], w[lik[l]]}; 50 return; 51 } 52 int m = (l + r) >> 1; 53 build(o << 1, l, m), build(o << 1 | 1, m + 1, r); 54 t[o] = (Tree) {max(t[o << 1].m, t[o << 1 | 1].m), t[o << 1].s + t[o << 1 | 1].s}; 55 } 56 57 void upd(int o, int l, int r, int x, int w) { 58 if (l == r) { 59 t[o].m += w, t[o].s += w; 60 return; 61 } 62 int m = (l + r) >> 1; 63 if (x <= m) upd(o << 1, l, m, x, w); 64 else upd(o << 1 | 1, m + 1, r, x, w); 65 t[o] = (Tree) {max(t[o << 1].m, t[o << 1 | 1].m), t[o << 1].s + t[o << 1 | 1].s}; 66 } 67 68 int quem(int o, int l, int r, int ql, int qr) { 69 int m = (l + r) >> 1, res = -INF; 70 if (ql <= l && r <= qr) return t[o].m; 71 if (ql <= m) res = max(res, quem(o << 1, l, m, ql, qr)); 72 if (qr > m) res = max(res, quem(o << 1 | 1, m + 1, r, ql, qr)); 73 return res; 74 } 75 76 int qmax() { 77 int x = top[u], y = top[v], ans = -INF; 78 while (x != y) { 79 if (d[x] < d[y]) swap(x, y), swap(u, v); 80 ans = max(ans, quem(1, 1, n, num[x], num[u])); 81 u = f[x], x = top[u]; 82 } 83 if (d[u] > d[v]) swap(u, v); 84 return max(ans, quem(1, 1, n, num[u], num[v])); 85 } 86 87 int ques(int o, int l, int r, int ql, int qr) { 88 int m = (l + r) >> 1, res = 0; 89 if (ql <= l && r <= qr) return t[o].s; 90 if (ql <= m) res += ques(o << 1, l, m, ql, qr); 91 if (qr > m) res += ques(o << 1 | 1, m + 1, r, ql, qr); 92 return res; 93 } 94 95 int qsum() { 96 int x = top[u], y = top[v], ans = 0; 97 while (x != y) { 98 if (d[x] < d[y]) swap(x, y), swap(u, v); 99 ans += ques(1, 1, n, num[x], num[u]); 100 u = f[x], x = top[u]; 101 } 102 if (d[u] > d[v]) swap(u, v); 103 return ans + ques(1, 1, n, num[u], num[v]); 104 } 105 106 int main() { 107 scanf("%d", &n); 108 for (int i = 1; i <= n - 1; i++) scanf("%d %d", &u, &v), add(u, v); 109 for (int i = 1; i <= n; i++) scanf("%d", &w[i]); 110 geth(1, 0, 1), mark(1, 1), build(1, 1, n); 111 scanf("%d", &q); 112 for (int i = 1; i <= q; i++) { 113 scanf("%s %d %d", ch, &u, &v); 114 if (ch[1] == 'H') upd(1, 1, n, num[u], v - w[u]), w[u] = v; 115 else printf("%d\n", ch[1] == 'S' ? qsum() : qmax()); 116 } 117 return 0; 118 }