[JSOI2019]神經網絡(樹形DP+容斥+生成函數)


首先可以把題目轉化一下:把樹拆成若干條鏈,每條鏈的顏色為其所在的樹的顏色,然后排放所有的鏈成環,求使得相鄰位置顏色不同的排列方案數。

然后本題分為兩個部分:將一棵樹分為1~n條不相交的鏈的方案數;將這些鏈安排順序使得不存在兩條相鄰的鏈來自同一棵樹。

第一部分顯然可以O(n2)樹形DP,f[i][j][0/1/2]表示i及其子樹j條鏈,i向兒子連出0/1/2條邊的方案數,然后直接背包DP即可。看似O(n3)的樹形背包DP其實是O(n2)的。證明復雜度:其實DP時只循環到sz[u]/sz[v]即可,然后可以把每個轉移視為兒子v內子樹的每個節點和節點u內v外節點組成的點對,於是全部DP完就是枚舉了所有的點對,復雜度顯然O(n2)。

第二部分,考慮n個點的樹划分成i條鏈的方案是f[i],如果不考慮環只考慮鏈其對應的指數生成函數為Σf[i]i!(Σ(-1)i-jC(i-1,i-j)xj/j!),其中i∈[1,n],j∈[1,i]。拓展到環上,欽定一棵樹作為開頭,如果該顏色有i條鏈,則被算了i次,然后其指數生成函數為:Σf[i](i-1)!(Σ(-1)i-jC(i-1,i-j)xj-1/(j-1)!),其中i∈[1,n],j∈[1,i]。減去首尾同色后,生成函數是這樣的:Σf[i](i-1)!(Σ(-1)i-jC(i-1,i-j)xj-2/(j-2)!),其中i∈[2,n],j∈[2,i]。然后暴力卷積即可。

#include<bits/stdc++.h>
using namespace std;
const int N=5005,mod=998244353;
int n,m,sum,ans,fac[N],inv[N],sz[N],f[N][N][3],g[N],tmp[N][3],dp[310][N],b[N];
vector<int>G[N];
int qpow(int a,int b)
{
    int ret=1;
    while(b)
    {
        if(b&1)ret=1ll*ret*a%mod;
        a=1ll*a*a%mod,b>>=1;
    }
    return ret;
}
void dfs(int u,int fa)
{
    sz[u]=1,f[u][1][0]=1;
    for(int i=0;i<G[u].size();i++)
    if(G[u][i]!=fa)
    {
        int v=G[u][i];
        dfs(v,u);
        for(int j=0;j<=sz[u]+sz[v];j++)tmp[j][0]=tmp[j][1]=tmp[j][2]=0;
        for(int j=1;j<=sz[u];j++)
        for(int k=1;k<=sz[v];k++)
        {
            tmp[j+k][0]=(tmp[j+k][0]+1ll*f[u][j][0]*(f[v][k][0]+2ll*f[v][k][1]+2ll*f[v][k][2]))%mod;
            tmp[j+k-1][1]=(tmp[j+k-1][1]+1ll*f[u][j][0]*(f[v][k][0]+f[v][k][1]))%mod;
            tmp[j+k][1]=(tmp[j+k][1]+1ll*f[u][j][1]*(f[v][k][0]+2ll*f[v][k][1]+2ll*f[v][k][2]))%mod;
            tmp[j+k-1][2]=(tmp[j+k-1][2]+1ll*f[u][j][1]*(f[v][k][0]+f[v][k][1]))%mod;
            tmp[j+k][2]=(tmp[j+k][2]+1ll*f[u][j][2]*(f[v][k][0]+2ll*f[v][k][1]+2ll*f[v][k][2]))%mod;
        }
        sz[u]+=sz[v];
        for(int j=1;j<=sz[u];j++)f[u][j][0]=tmp[j][0],f[u][j][1]=tmp[j][1],f[u][j][2]=tmp[j][2];
    }
}
int C(int a,int b){return a<b?0:1ll*fac[a]*inv[b]%mod*inv[a-b]%mod;}
int S(int a,int b){return (!a&&!b)?1:1ll*fac[a]*C(a-1,a-b)%mod;}
int main()
{
    fac[0]=1;for(int i=1;i<=5000;i++)fac[i]=1ll*fac[i-1]*i%mod;
    for(int i=0;i<=5000;i++)inv[i]=qpow(fac[i],mod-2);
    scanf("%d",&m);
    dp[0][0]=1;
    for(int p=1;p<=m;p++)
    {
        scanf("%d",&n);
        for(int i=1;i<=n;i++)G[i].clear();
        for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),G[x].push_back(y),G[y].push_back(x);
        for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
        f[i][j][0]=f[i][j][1]=f[i][j][2]=0;
        dfs(1,0);
        memset(g,0,sizeof g);
        for(int i=1;i<=n;i++)g[i]=(f[1][i][0]+2ll*f[1][i][1]+2ll*f[1][i][2])%mod;
        if(p!=m)
        {
            memset(b,0,sizeof b);
            for(int j=1;j<=n;j++)
            if(g[j])for(int k=0,t=1;k<=j;k++,t=mod-t)
            b[j-k]=(b[j-k]+1ll*t*S(j,j-k)%mod*g[j])%mod;
            for(int i=0;i<=sum;i++)
            if(dp[p-1][i])for(int j=0;j<=n;j++)
            dp[p][i+j]=(dp[p][i+j]+1ll*C(i+j,j)*b[j]%mod*dp[p-1][i])%mod;
        }
        else{
            memset(b,0,sizeof b);
            for(int j=1;j<=n;j++)
            if(g[j])for(int k=0,t=1;k<j;k++,t=mod-t)
            b[j-1-k]=(b[j-1-k]+1ll*t*S(j-1,j-k-1)%mod*g[j])%mod;
            for(int i=0;i<=sum;i++)
            if(dp[p-1][i])for(int j=0;j<=n;j++)
            ans=(ans+1ll*C(i-2+j,j)*b[j]%mod*dp[p-1][i])%mod;
        }
        sum+=n;
    }
    printf("%d",ans);
}
View Code

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM