Luogu P4643 【模板】動態dp


題目鏈接

Luogu P4643

題解

貓錕在WC2018講的黑科技——動態DP,就是一個畫風正常的DP問題再加上一個動態修改操作,就像這道題一樣。(這道題也是PPT中的例題)

動態DP的一個套路是把DP轉移方程寫成矩陣乘法,然后用線段樹(樹上的話就是樹剖)維護矩陣,這樣就可以做到修改了。

注意這個“矩陣乘法”不一定是我們常見的那種乘法和加法組成的矩陣乘法。設\(A * B = C\),常見的那種矩陣乘法是這樣的:

\[C_{i, j} = \sum_{k = 1}^{n} A_{i, k} * B_{k, j} \]

而這道題中的矩陣乘法是這樣的:

\[C_{i, j} = \max_{k = 1}^{n} (A_{i, k} + B_{k, j}) \]

這就相當於常見矩陣乘法中的加法變成了max,乘法變成了加法。類似於乘法和加法的五種運算律,這兩種變化也滿足“加法交換律”、“加法結合律”、“max交換律”、“max結合律”和“加法分配律“。那么這種矩陣乘法顯然也滿足矩陣乘法結合律,就像正常的矩陣乘法一樣,可以用線段樹維護。

接下來我們來構造矩陣。首先研究DP方程。

就像“沒有上司的舞會”一樣,\(f_{i, 0}\)表示子樹\(i\)中不選\(i\)的最大權獨立集大小,\(f_{i, 1}\)表示子樹\(i\)中選\(i\)的最大權獨立集大小。

但這是動態DP,我們需要樹鏈剖分。假設我們已經完成了樹鏈剖分,剖出來的某條重鏈看起來就像這樣,右邊的是在樹上深度較大的點:

此時,比這條重鏈的top深度大且不在這條重鏈上的點的DP值都是已經求出來的(這可以做到)。我們把它們的貢獻,都統一於它們在這條重鏈上對應的那個祖先上。

具體來說,設\(g_{i, 0}\)表示不選\(i\)時,\(i\)不在鏈上的子孫的最大權獨立集大小,\(g_{i, 1}\)表示選\(i\)時,\(i\)不在鏈上的子孫再加上\(i\)自己的最大權獨立集大小。

假如\(i\)右面的點是\(i + 1\), 那么可以得出:

\[f_{i, 0} = g_{i, 0} + \max(f_{i + 1, 0}, f_{i + 1, 1}) \]

\[f_{i, 1} = g_{i, 1} + f_{i + 1, 0} \]

矩陣也就可以構造出來了:

\[\begin{bmatrix}g_{i, 0} & g_{i, 0} \\g_{i, 1} & 0\end{bmatrix} * \begin{bmatrix}f_{i + 1, 0} \\ f_{i + 1, 1}\end{bmatrix} = \begin{bmatrix}f_{i, 0} \\ f_{i, 1}\end{bmatrix} \]

讀者可以動筆驗證一下。(注意我們在這里用的“新矩陣乘法”的規則:原來的乘變成加,加變成取max。)

那么基本思路就很清楚了:樹剖,維護區間矩陣乘積。修改的時候,對於被修改節點到根節點路徑上的每個重鏈(由下到上),先進行單點修改,然后求出這條重鏈的\(top\)在修改之后的\(f\)值,然后繼續修改top所在重鏈。

每次答案就是節點\(1\)\(f\)值。

代碼

代碼略丑,見諒……

#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <queue>
#define space putchar(' ')
#define enter putchar('\n')
using namespace std;
typedef long long ll;
template <class T>
void read(T &x){
    char c;
    bool op = 0;
    while(c = getchar(), c < '0' || c > '9')
	if(c == '-') op = 1;
    x = c - '0';
    while(c = getchar(), c >= '0' && c <= '9')
	x = x * 10 + c - '0';
    if(op) x = -x;
}
template <class T>
void write(T x){
    if(x < 0) putchar('-'), x = -x;
    if(x >= 10) write(x / 10);
    putchar('0' + x % 10);
}

const int N = 100005;
int n, m, a[N];
int ecnt, adj[N], nxt[2*N], go[2*N];
int fa[N], son[N], sze[N], top[N], idx[N], pos[N], tot, ed[N];
ll f[N][2];

struct matrix {
    ll g[2][2];
    matrix(){
	memset(g, 0, sizeof(g));
    }
    matrix operator * (const matrix &b) const {
	matrix c;
	for(int i = 0; i < 2; i++)
	    for(int j = 0; j < 2; j++)
		for(int k = 0; k < 2; k++)
		    c.g[i][j] = max(c.g[i][j], g[i][k] + b.g[k][j]);
	return c;
    }
} val[N], data[4*N];

void add(int u, int v){
    go[++ecnt] = v;
    nxt[ecnt] = adj[u];
    adj[u] = ecnt;
}

void init(){
    static int que[N];
    que[1] = 1;
    for(int ql = 1, qr = 1; ql <= qr; ql++)
	for(int u = que[ql], e = adj[u], v; e; e = nxt[e])
	    if((v = go[e]) != fa[u])
		fa[v] = u, que[++qr] = v;
    for(int qr = n, u; qr; qr--){
	sze[u = que[qr]]++;
	sze[fa[u]] += sze[u];
	if(sze[u] > sze[son[fa[u]]])
	    son[fa[u]] = u;
    }
    for(int ql = 1, u; ql <= n; ql++)
	if(!top[u = que[ql]]){
	    for(int v = u; v; v = son[v])
		top[v] = u, idx[pos[v] = ++tot] = v;
	    ed[u] = tot;
	}
    for(int qr = n, u; qr; qr--){
	u = que[qr];
	f[u][1] = max(0, a[u]);
	for(int e = adj[u], v; e; e = nxt[e])
	    if(v = go[e], v != fa[u]){
		f[u][0] += max(f[v][0], f[v][1]);
		f[u][1] += f[v][0];
	    }
    }
}

void build(int k, int l, int r){
    if(l == r){
	ll g0 = 0, g1 = a[idx[l]];
	for(int u = idx[l], e = adj[u], v; e; e = nxt[e])
	    if((v = go[e]) != fa[u] && v != son[u])
		g0 += max(f[v][0], f[v][1]), g1 += f[v][0];
	data[k].g[0][0] = data[k].g[0][1] = g0;
	data[k].g[1][0] = g1;
	val[l] = data[k];
	return;
    }
    int mid = (l + r) >> 1;
    build(k << 1, l, mid);
    build(k << 1 | 1, mid + 1, r);
    data[k] = data[k << 1] * data[k << 1 | 1];
}
void change(int k, int l, int r, int p){
    if(l == r){
	data[k] = val[l];
	return;
    }
    int mid = (l + r) >> 1;
    if(p <= mid) change(k << 1, l, mid, p);
    else change(k << 1 | 1, mid + 1, r, p);
    data[k] = data[k << 1] * data[k << 1 | 1];
}
matrix query(int k, int l, int r, int ql, int qr){
    if(ql <= l && qr >= r) return data[k];
    int mid = (l + r) >> 1;
    if(qr <= mid) return query(k << 1, l, mid, ql, qr);
    if(ql > mid) return query(k << 1 | 1, mid + 1, r, ql, qr);
    return query(k << 1, l, mid, ql, qr) * query(k << 1 | 1, mid + 1, r, ql, qr);
}
matrix ask(int u){
    return query(1, 1, n, pos[top[u]], ed[top[u]]);
}
void path_change(int u, int x){
    val[pos[u]].g[1][0] += x - a[u];
    a[u] = x;
    matrix od, nw;
    while(u){
	od = ask(top[u]);
	change(1, 1, n, pos[u]);
	nw = ask(top[u]);
	u = fa[top[u]];
	val[pos[u]].g[0][0] += max(nw.g[0][0], nw.g[1][0]) - max(od.g[0][0], od.g[1][0]);
	val[pos[u]].g[0][1] = val[pos[u]].g[0][0];
	val[pos[u]].g[1][0] += nw.g[0][0] - od.g[0][0];
    }
}

int main(){

    read(n);
    read(m);
    for(int i = 1; i <= n; i++) read(a[i]);
    for(int i = 1, u, v; i < n; i++)
	read(u), read(v), add(u, v), add(v, u);
    init();
    build(1, 1, n);
    int u, x;
    matrix t;
    while(m--){
	read(u), read(x);
	path_change(u, x);
	t = ask(1);
	write(max(t.g[0][0], t.g[1][0])), enter;
    }

    return 0;
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM