[學習筆記]整體DP


問題:

有一些問題,通常見於二維的DP,另一維記錄當前x的信息,但是這一維過大無法開下,O(nm)也無法通過。

但是如果發現,對於x,在第二維的一些區間內,取值都是相同的,並且這樣的區間是有限個,就可以批量處理

 

思想:

通過動態開點線段樹維護第二維,

如果某個節點沒有兒子,那么這個節點區間都是同一個權值。

也即,一個節點是空節點,那么這個節點所有的值和父親的值都一致。(其實它的兄弟也是空節點的)

對於序列的問題,

可以直接掃過去,修改某些位置的點。

或者線段樹合並。

對於樹上的問題,

線段樹合並。

 

實現:

主要考慮什么時候線段樹合並停止。以及pushdown的標記問題。

當x都沒有兒子或者y都沒有兒子時候,整個x的區間或整個y的區間都是同一個值,可以直接計算貢獻轉移過來(這個必須支持,否則不能整體DP)。

否則,pushdown,進行遞歸

pushdown時候建立新的兒子(如果之前沒有)。

空間復雜度和時間復雜度基本一致。O(nlogn)

 

只要滿足,在x都沒有兒子或者y都沒有兒子時候,可以快速合並然后return,那么就可以整體DP了。

 

例題1:[九省聯考2018]秘密襲擊coat

例題2:

 

$dp[x][c]=\Pi (sumy-dp[y][c])$sumy表示y的所有dp[y][*]的和

在x都沒有兒子或者y都沒有兒子時候,我們要么知道每個x的值,要么知道每個y的值。

在x都沒有兒子時候,把y的節點內每個數乘-1再加sumy,再乘上x區間的值。

y都沒有兒子時候,直接用(sumy-val)乘給x即可。

code:

#include<bits/stdc++.h>
#define reg register int
#define il inline
#define fi first
#define se second
#define mk(a,b) make_pair(a,b)
#define numb (ch^'0')
#define pb push_back
#define solid const auto &
#define enter cout<<endl
#define pii pair<int,int>
using namespace std;
typedef long long ll;
template<class T>il void rd(T &x){
    char ch;x=0;bool fl=false;while(!isdigit(ch=getchar()))(ch=='-')&&(fl=true);
    for(x=numb;isdigit(ch=getchar());x=x*10+numb);(fl==true)&&(x=-x);}
template<class T>il void output(T x){if(x/10)output(x/10);putchar(x%10+'0');}
template<class T>il void ot(T x){if(x<0) putchar('-'),x=-x;output(x);putchar(' ');}
template<class T>il void prt(T a[],int st,int nd){for(reg i=st;i<=nd;++i) ot(a[i]);putchar('\n');}
namespace Modulo{
const int mod=998244353;
int ad(int x,int y){return (x+y)>=mod?x+y-mod:x+y;}
void inc(int &x,int y){x=ad(x,y);}
int mul(int x,int y){return (ll)x*y%mod;}
void inc2(int &x,int y){x=mul(x,y);}
int qm(int x,int y=mod-2){int ret=1;while(y){if(y&1) ret=mul(x,ret);x=mul(x,x);y>>=1;}return ret;}
}
using namespace Modulo;
namespace Miracle{
const int N=2e5+5;
int n,m,k;
struct node{
    int nxt,to;
}e[2*N];
int hd[N],cnt;
void add(int x,int y){
    e[++cnt].nxt=hd[x];
    e[cnt].to=y;
    hd[x]=cnt;
}
#define mid ((l+r)>>1)
struct tr{
    int sum,mul,ad;
    int ls,rs,val;
    void op(){
        cout<<"SUM "<<sum<<" MUL "<<mul<<" AD "<<ad<<endl;
    }
}t[20000000+3];
int tot,S;
vector<int>no[N];
int rt[N];
int nc(){
    ++tot;
    t[tot].sum=0;t[tot].mul=1;t[tot].ad=0;
    t[tot].ls=t[tot].rs=0;t[tot].val=0;
    return tot;
}
void tag(int x,int l,int r,int ml,int aa){
    // cout<<" tag "<<x<<" l "<<l<<" r "<<r<<" ml "<<ml<<" ad "<<aa<<endl;
    // t[x].op();
    t[x].sum=mul(t[x].sum,ml);
    t[x].sum=ad(t[x].sum,mul(r-l+1,aa));
    t[x].val=ad(mul(t[x].val,ml),aa);
    t[x].mul=mul(t[x].mul,ml);
    t[x].ad=ad(mul(t[x].ad,ml),aa);
}
void pushup(int x){
    t[x].sum=ad(t[t[x].ls].sum,t[t[x].rs].sum);
}
void pushdown(int x,int l,int r){
    if(!t[x].ls) t[x].ls=nc();
    if(!t[x].rs) t[x].rs=nc();
    tag(t[x].ls,l,mid,t[x].mul,t[x].ad);
    tag(t[x].rs,mid+1,r,t[x].mul,t[x].ad);
    t[x].mul=1;t[x].ad=0;
}
void upda(int &x,int l,int r,int p){
    // cout<<" pp "<<p<<" x "<<x<<" l "<<l<<" r "<<r<<" sm "<<t[x].sum<<" mul "<<t[x].mul<<" ad "<<t[x].ad<<endl;
    // cout<<" ls "<<t[x].ls<<" rs "<<t[x].rs<<endl; 
    if(!x) x=nc();
    if(l==r){
        // cout<<" ss "<<t[x].sum<<endl;
        t[x].sum=0;
        t[x].val=0;
        return;
    }
    pushdown(x,l,r);
    if(p<=mid) upda(t[x].ls,l,mid,p);
    else upda(t[x].rs,mid+1,r,p);
    pushup(x);
}
int merge(int x,int y,int l,int r){
    if(!t[x].ls&&!t[x].rs){
        swap(x,y);
        int v=t[y].val;
        tag(x,l,r,mod-1,S);
        tag(x,l,r,v,0);
    }else if(!t[y].ls&&!t[y].rs){
        int v=t[y].val;
        tag(x,l,r,ad(S,mod-v),0);
    }else{
        pushdown(x,l,r);pushdown(y,l,r);
        t[x].ls=merge(t[x].ls,t[y].ls,l,mid);
        t[x].rs=merge(t[x].rs,t[y].rs,mid+1,r);
        pushup(x);
    }
    return x;//warining!!
}
void dfs(int x,int fa){
    rt[x]=nc();
    tag(rt[x],1,m,1,1);
    for(reg i=hd[x];i;i=e[i].nxt){
        int y=e[i].to;
        if(y==fa) continue;
        dfs(y,x);
        S=t[rt[y]].sum;
        rt[x]=merge(rt[x],rt[y],1,m);
        // cout<<y<<" back "<<x<<" : "<<" sum "<<t[rt[x]].sum<<endl;
    }
    for(solid c:no[x]){
        upda(rt[x],1,m,c);
    }
    // cout<<x<<" : "<<" sum "<<t[rt[x]].sum<<endl;
}
int main(){
    rd(n);rd(m);rd(k);
    int x,y;
    for(reg i=1;i<n;++i){
        rd(x);rd(y);
        add(x,y);add(y,x);
    }
    for(reg i=1;i<=k;++i){
        rd(x);rd(y);
        no[x].push_back(y);
    }
    dfs(1,0);
    printf("%d",t[rt[1]].sum);
    return 0;
}

}
signed main(){
    Miracle::main();
    return 0;
}

/*
   Author: *Miracle*
*/
View Code

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM