【bzoj5210】最大連通子塊和 樹鏈剖分+線段樹+可刪除堆維護樹形動態dp


題目描述

給出一棵n個點、以1為根的有根樹,點有點權。要求支持如下兩種操作:
M x y:將點x的點權改為y;
Q x:求以x為根的子樹的最大連通子塊和。
其中,一棵子樹的最大連通子塊和指的是:該子樹所有子連通塊的點權和中的最大值
(本題中子連通塊包括空連通塊,點權和為0)。

輸入

第一行兩個整數n、m,表示樹的點數以及操作的數目。
第二行n個整數,第i個整數w_i表示第i個點的點權。
接下來的n-1行,每行兩個整數x、y,表示x和y之間有一條邊相連。
接下來的m行,每行輸入一個操作,含義如題目所述。保證操作為M x y或Q x之一。
1≤n,m≤200000 ,任意時刻 |w_i|≤10^9 。

輸出

對於每個Q操作輸出一行一個整數,表示詢問子樹的最大連通子塊和。

樣例輸入

5 4
3 -2 0 3 -1
1 2
1 3
4 2
2 5
Q 1
M 4 1
Q 1
Q 2

樣例輸出

4
3
1


題解

樹鏈剖分+線段樹+可刪除堆維護樹形動態dp

如果dp是靜態的,設 $f[i]$ 表示以 $i$ 為根的子樹中,選出 包括 $i$ 的連通子塊 或 空塊 的最大點權和。那么有 $f[i]=\text{max}(v[i]+\sum\limits_{i\to j}f[j],0)$ 。所求即是子樹內所有點的 $f$ 值的最大值。

當這個dp在序列上進行時,容易轉化為最大連續子段和的形式。

當這個dp在樹上進行時,考慮將這棵樹輕重鏈剖分,轉化為序列問題。

設 $y$ 為 $x$ 的重兒子,所有 $x$ 的輕兒子的 $f$ 值加上 $v[x]$ 為 $g[x]$ ,那么有 $f[x]=\text{max}(f[y]+g[x],0)$ 。

這個形式類似於最小連續子段和中的最小前綴和。使用線段樹維護最小前綴和(在重鏈這一段區間的某位置選出一個點使得 不選鏈頂到該點父親,其余選最大的 最大)及總和(都不選)。線段樹的葉子節點的最小前綴和和總和都是 $g$ 。

修改時,首先 $v[x]$ 修改導致 $g[x]$ 修改;然后使鏈頂的 $f$ 值修改,影響鏈頂父親的 $g$ ,再不斷修改即可。

查詢時,一個點的 $f$ 值就是該點到鏈底節點的最小前綴和。

然而答案是子樹內所有 $f$ 的最大值,因此不能僅僅維護最小前綴和。

考慮重鏈上的部分:其實相當於每一個后綴的前綴中最大的那個,即子段中最大的那個。因此維護最大連續子段和即可直接得出鏈上所有點的 $f$ 的最大值。

考慮輕鏈上的部分:一條重鏈上的答案對鏈頂父親有貢獻,將這個答案加到鏈頂父親對應葉子節點的最大連續子段和即可。即:一個點對應葉子節點初始的最大連續子段和為:該節點的 $v$ 值與該節點輕兒子所在重鏈的最大連續子段和的最大值。我們對每個節點再維護這個最大值即可。由於要支持修改、查詢最值,因此使用可刪除堆(或者STL-set)。

這樣查詢時查詢該點到鏈底的最大連續子段和就是答案了。

修改的時間復雜度為 $O(\log^2n)$ ,詢問的時間復雜度為 $O(\log n)$ 。

#include <queue>
#include <cstdio>
#include <algorithm>
#define N 200010
#define lson l , mid , x << 1
#define rson mid + 1 , r , x << 1 | 1
using namespace std;
typedef long long ll;
struct data
{
	ll sum , ls , rs , ts;
	inline friend data operator+(const data &a , const data &b)
	{
		data ans;
		ans.sum = a.sum + b.sum;
		ans.ls = max(a.ls , a.sum + b.ls);
		ans.rs = max(b.rs , b.sum + a.rs);
		ans.ts = max(a.rs + b.ls , max(a.ts , b.ts));
		return ans;
	}
}a[N << 2];
struct heap
{
	priority_queue<ll> A , B;
	inline void push(ll x) {A.push(x);}
	inline void del(ll x) {B.push(x);}
	inline ll top()
	{
		while(!B.empty() && A.top() == B.top()) A.pop() , B.pop();
		return A.top();
	}
}q[N];
int n , v[N] , head[N] , to[N << 1] , next[N << 1] , cnt , fa[N] , si[N] , bl[N] , end[N] , pos[N] , tot;
ll f[N] , ms[N] , w[N];
char str[5];
inline void add(int x , int y)
{
	to[++cnt] = y , next[cnt] = head[x] , head[x] = cnt;
}
void dfs1(int x)
{
	int i;
	si[x] = 1;
	for(i = head[x] ; i ; i = next[i])
		if(to[i] != fa[x])
			fa[to[i]] = x , dfs1(to[i]) , si[x] += si[to[i]];
}
void dfs2(int x , int c)
{
	int i , k = 0;
	bl[x] = c , pos[x] = ++tot , w[pos[x]] = v[x];
	for(i = head[x] ; i ; i = next[i])
		if(to[i] != fa[x] && si[to[i]] > si[k])
			k = to[i];
	if(k)
	{
		dfs2(k , c) , f[x] = f[k] , ms[x] = ms[k] , end[x] = end[k];
		for(i = head[x] ; i ; i = next[i])
			if(to[i] != fa[x] && to[i] != k)
				dfs2(to[i] , to[i]) , w[pos[x]] += f[to[i]] , q[pos[x]].push(ms[to[i]]);
	}
	else end[x] = x;
	f[x] = max(f[x] + w[pos[x]] , 0ll) , ms[x] = max(ms[x] , max(f[x] , q[pos[x]].top()));
}
void build(int l , int r , int x)
{
	if(l == r)
	{
		a[x].sum = w[l] , a[x].ls = a[x].rs = max(w[l] , 0ll) , a[x].ts = max(w[l] , q[l].top());
		return;
	}
	int mid = (l + r) >> 1;
	build(lson) , build(rson);
	a[x] = a[x << 1] + a[x << 1 | 1];
}
void fix(int p , int l , int r , int x)
{
	if(l == r)
	{
		a[x].sum = w[l] , a[x].ls = a[x].rs = max(w[l] , 0ll) , a[x].ts = max(w[l] , q[l].top());
		return;
	}
	int mid = (l + r) >> 1;
	if(p <= mid) fix(p , lson);
	else fix(p , rson);
	a[x] = a[x << 1] + a[x << 1 | 1];
}
data query(int b , int e , int l , int r , int x)
{
	if(b <= l && r <= e) return a[x];
	int mid = (l + r) >> 1;
	if(e <= mid) return query(b , e , lson);
	else if(b > mid) return query(b , e , rson);
	else return query(b , e , lson) + query(b , e , rson);
}
void modify(int x , int z)
{
	data a , b;
	bool flag = 0;
	a.ls = w[pos[x]] , b.ls = w[pos[x]] - v[x] + z , v[x] = z;
	while(x)
	{
		w[pos[x]] += b.ls - a.ls;
		if(flag) q[pos[x]].del(a.ts) , q[pos[x]].push(b.ts);
		a = query(pos[bl[x]] , pos[end[x]] , 1 , n , 1);
		fix(pos[x] , 1 , n , 1);
		b = query(pos[bl[x]] , pos[end[x]] , 1 , n , 1);
		x = fa[bl[x]] , flag = 1;
	}
}
int main()
{
	int m , i , x , y;
	scanf("%d%d" , &n , &m);
	for(i = 1 ; i <= n ; i ++ ) scanf("%d" , &v[i]) , q[i].push(0);
	for(i = 1 ; i < n ; i ++ ) scanf("%d%d" , &x , &y) , add(x , y) , add(y , x);
	dfs1(1) , dfs2(1 , 1);
	build(1 , n , 1);
	while(m -- )
	{
		scanf("%s%d" , str , &x);
		if(str[0] == 'M') scanf("%d" , &y) , modify(x , y);
		else printf("%lld\n" , query(pos[x] , pos[end[x]] , 1 , n , 1).ts);
	}
	return 0;
}

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM