\(O(nL)\)的\(DP\)很普及組吧。點減邊容斥,設\(f_{u,i}-1\)為在\(u\)子樹內選出一個連通塊,使得它包含\(u\)且最深點距離\(u\)為\(i\)的方案數,\(g_{u,i}\)表示選出一個連通塊,使得它包含\(u\)且不包含\(u\)子樹內除\(u\)以外的點,距離\(u\)最遠的點距離為\(i\)的方案數。
顯然\(f\)可以用長鏈剖分優化,也很普及組,就不說了。
\(g\)其實也是可以用長鏈剖分優化的。將當前點的\(g\)轉移到重兒子時像\(f\)一樣直接繼承,轉移到輕兒子時,我們發現最后計算答案只用到了所有的\(g_{u,L}\),所以只有\(g_{u,L-dep[u]}\sim g_{u,L}\)這部分的值是有用的,暴力轉移這一部分就行了,復雜度和長鏈剖分一樣分析。
求\(g\)時有一個問題,就是需要求一個點的所有子樹除掉某個子樹外的\(f\)的積。暴力的想法是維護可持久化線段樹,復雜度\(O(nlogn)\)很不優秀。不過可以發現可持久化是假的,只要可回退化就行了。具體的,在求\(f\)的dfs中,每次合並一個子樹時,我們記錄下來合並這個子樹時對\(f\)相關信息進行的修改。在求\(g\)的dfs中,我們反序遍歷每個點的兒子,這樣就可以棧序撤銷,維護每一個前綴的\(f\)的積,后綴再開另一個數組記錄一下就行了。
不過這樣還是需要線段樹。我們發現我們要支持的是全局加,后綴乘。我們可以對每個點記錄標記\(a,b\),表示真實值為\(a*\)存儲值+\(b\)。全局加時修改\(b\)即可,后綴乘時修改\(a,b\),並暴力修改沒有被后綴乘影響到的元素,乘一個逆元即可。如果\(a\)為\(0\),那么\(a\)不存在逆元,就只能將后綴乘變為后綴賦值為\(-\frac{b}{a}\)。所以還要記錄標記\(p,t\),表示所有\(>=p\)的位置被賦值為\(t\),每次修改一個元素時下放后綴賦值標記。這樣就能做到\(O(n)\)了。
要真正做到\(O(n)\),還需要線性求逆元。我們發現要求的逆元都是\(f_{u,dep[u]}\)的形式。把這些值排成一個數組記為\(a\),並求出\(a\)的前綴積\(b\)。我們算出\(b_n^{-1}\),遞推出所有\(b^{-1}\),再利用\(b\)和\(b^{-1}\)計算\(a^{-1}\)即可。
實現時有幾個trick:
1.修改本質都是對內存位置的修改,所以需要支持撤銷時只要保存內存中哪個位置原來的值是多少就行了。
2.stack空間占用很大,可以用list代替。
3.給\(g\)分配空間時可以讓\(g_u\)的起始位置指向\(p-(L-dep[u])\),其中\(p\)是當前內存池中第一個可用位置,這樣訪問\(g_{u,L-dep[u]}\)時就會訪問位置\(p\)了。
出題人真良心,代碼只要寫3.3K。
#include<bits/stdc++.h>
#define foo int a[N],ia[N],b[N],p[N],t[N];int trs(int u,int w){return 1ll*J(w,b[u])*ia[u]%M;}
using namespace std;const int N=1e6+9;const int M=998244353;int gi(){int x;cin>>x;return x;}int P(int a,int b){return a+b>=M?a+b-M:a+b;}int J(int a,int b){return a-b<0?a-b+M:a-b;}void I(int&a,int b){a=a+b>=M?a+b-M:a+b;}void K(int&a,int b){a=a-b<0?a-b+M:a-b;}int qpow(int a,int b){int ret=1;while(b){if(b&1)ret=1ll*ret*a%M;a=1ll*a*a%M,b>>=1;}return ret;}vector<int>E[N];int n,L,k,U[N],O[N],V[N],ans=0,s[N],m=0,ff[N*10],*f[N],*pf=ff,gg[N*10],*g[N],*pg=gg,h[N];namespace F{foo int val(int u,int i){i=min(i,U[u]);return P(1ll*a[u]*(i<p[u]?f[u][i]:t[u])%M,b[u]);}}namespace G{foo int val(int u,int i){return P(1ll*a[u]*(i<p[u]?g[u][i]:t[u])%M,b[u]);}}bool cmp(int a,int b){return U[a]>U[b];}void D1(int u,int fa){U[u]=-1,V[u]=1;for(auto v:E[u])if(v!=fa)D1(v,u),U[v]>U[u]?O[u]=v,U[u]=U[v]:0,V[u]=1ll*V[u]*V[v]%M;sort(E[u].begin(),E[u].end(),cmp);++U[u],I(V[u],1);}void New(int u){f[u]=pf;pf+=U[u]+2;pg+=U[u]+2;g[u]=pg-max(0,L-U[u]);pg+=U[u]+2;}struct op{int x,*p;};list<op>st[N];void S(int u,int &x){st[u].push_back((op){x,&x});}void undo(int u){op t;while(!st[u].empty())t=st[u].back(),*t.p=t.x,st[u].pop_back();}void D2(int u,int fa){using namespace F;int z=0;p[u]=U[u]+1;a[u]=ia[u]=b[u]=1;if(O[u])f[z=O[u]]=f[u]+1,D2(z,u),a[u]=a[z],ia[u]=ia[z],b[u]=b[z],p[u]=p[z]+1,t[u]=t[z],f[u][0]=trs(u,1);for(auto v:E[u])if(v!=fa&&v!=O[u]){z=v,New(v),D2(v,u);for(int i=0;i<=U[v]+1;i++){if(p[u]==i)S(v,p[u]),S(v,f[u][i]),f[u][p[u]++]=t[u];S(v,f[u][i]),f[u][i]=trs(u,1ll*val(u,i)*(i?val(v,i-1):1)%M);}if(U[u]>U[v]+1){int w=val(v,U[v]);if(w){S(v,a[u]),S(v,ia[u]),S(v,b[u]),a[u]=1ll*a[u]*w%M,b[u]=1ll*b[u]*w%M,ia[u]=1ll*ia[u]*V[v]%M;for(int i=0;i<=U[v]+1;i++)S(v,f[u][i]),f[u][i]=trs(u,1ll*val(u,i)*V[v]%M);}else S(v,p[u]),S(v,t[u]),p[u]=U[v]+1,t[u]=trs(u,0);}}if(z)S(z,b[u]);I(b[u],1);}void D3(int u,int fa){using namespace G;reverse(E[u].begin(),E[u].end());int sx=0,prd=1;h[0]=1;I(ans,qpow(1ll*J(F::val(u,L),1)*val(u,L)%M,k));if(fa)K(ans,qpow(1ll*J(F::val(u,L-1),1)*J(val(u,L),1)%M,k));for(auto v:E[u])if(v!=fa&&v!=O[u]){undo(v);p[v]=L+1;a[v]=ia[v]=1;for(int i=max(0,L-U[v]);i<=L;i++)g[v][i]=P(1ll*(i?val(u,i-1):0)*(i>1?1ll*F::val(u,i-1)*(i-2<=sx?h[i-2]:prd)%M:1)%M,1);for(int i=0;i<=U[v];i++)h[i]=1ll*(i>sx?prd:h[i])*F::val(v,i)%M;sx=U[v];prd=1ll*prd*F::val(v,U[v])%M;}int z;if(O[u]){g[z=O[u]]=g[u]-1;a[z]=a[u];b[z]=b[u];ia[z]=ia[u];t[z]=t[u];p[z]=p[u]+1,L-U[z]<=0?g[z][0]=trs(z,0):0;for(auto v:E[u])if(v!=fa&&v!=O[u]){p[z]=max(p[z],L-U[z]);for(int i=max(0,L-U[z]);i<=min(L,U[v]+2);i++){if(p[z]==i)g[z][p[z]++]=t[z];g[z][i]=trs(z,1ll*val(z,i)*(i>1?F::val(v,i-2):1)%M);}if(L>U[v]+2){int w=F::val(v,U[v]);if(w){a[z]=1ll*a[z]*w%M,b[z]=1ll*b[z]*w%M,ia[z]=1ll*ia[z]*V[v]%M;for(int i=max(0,L-U[z]);i<=min(L,U[v]+2);i++)g[z][i]=trs(z,1ll*val(z,i)*V[v]%M);}else p[z]=U[v]+2,t[z]=trs(z,0);}}I(b[z],1);D3(z,u);}for(auto v:E[u])if(v!=fa&&v!=O[u])D3(v,u);}int main(){n=gi(),L=gi(),k=gi();if(!L)return printf("%d\n",n),0;for(int i=1,u,v;i<n;i++)u=gi(),v=gi(),E[u].push_back(v),E[v].push_back(u);D1(1,0);s[0]=1;for(int i=1;i<=n;i++)if(V[i])++m,s[m]=1ll*s[m-1]*V[i]%M;s[m]=qpow(s[m],M-2);for(int i=n,t;i;i--)if(V[i])--m,t=V[i],V[i]=1ll*s[m+1]*s[m]%M,s[m]=1ll*s[m+1]*t%M;New(1),D2(1,0);G::a[1]=G::ia[1]=G::b[1]=1;G::p[1]=L+1;D3(1,0);printf("%d\n",ans);return 0;}
正常代碼:
//HNOIday1t1出題人nmsl
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+10;
const int mod=998244353;
int gi() {
int x=0,o=1;char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if(ch=='-') o=-1,ch=getchar();
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
return x*o;
}
int add(int a,int b) {
return a+b>=mod?a+b-mod:a+b;
}
int sub(int a,int b) {
return a-b<0?a-b+mod:a-b;
}
void inc(int &a,int b) {
a=a+b>=mod?a+b-mod:a+b;
}
void dec(int &a,int b) {
a=a-b<0?a-b+mod:a-b;
}
int qpow(int a,int b) {
int ret=1;
while(b) {
if(b&1) ret=1ll*ret*a%mod;
a=1ll*a*a%mod,b>>=1;
}
return ret;
}
vector<int> E[N];
int n,L,k,dep[N],son[N],inv[N],ans=0,s[N],m=0,ff[N*10],*f[N],*pf=ff,gg[N*10],*g[N],*pg=gg,h[N];
vector<int> vis[N];
#define foo int a[N],ia[N],b[N],p[N],t[N]; \
int trs(int u,int w) { \
return 1ll*sub(w,b[u])*ia[u]%mod; \
} \
namespace F {
foo
int val(int u,int i) {
i=min(i,dep[u]);
return add(1ll*a[u]*(i<p[u]?f[u][i]:t[u])%mod,b[u]);
}
}
namespace G {
foo
int val(int u,int i) {
return add(1ll*a[u]*(i<p[u]?g[u][i]:t[u])%mod,b[u]);
}
}
bool cmp(int a,int b) {
return dep[a]>dep[b];
}
void dfs1(int u,int fa) {
dep[u]=-1,inv[u]=1;
for(auto v:E[u])
if(v!=fa) dfs1(v,u),dep[v]>dep[u]?son[u]=v,dep[u]=dep[v]:0,inv[u]=1ll*inv[u]*inv[v]%mod;
if(son[u]) {
int mx=0;
for(auto v:E[u]) if(v!=fa&&v!=son[u]) mx=max(mx,dep[v]),vis[dep[v]].push_back(v);
E[u]=vector<int>(1,son[u]);
for(int i=mx;~i;i--) for(auto v:vis[i]) E[u].push_back(v);
for(int i=mx;~i;i--) vis[i].clear();
}
++dep[u],inc(inv[u],1);
}
void New(int u) {
f[u]=pf;pf+=dep[u]+2;pg+=dep[u]+2;g[u]=pg-max(0,L-dep[u]);pg+=dep[u]+2;
}
struct op { int x,*p; };
list<op> st[N];
void S(int u,int &x) {
st[u].push_back((op){x,&x});
}
void undo(int u) {
op t;
while(!st[u].empty()) t=st[u].back(),*t.p=t.x,st[u].pop_back();
}
void dfs2(int u,int fa) {
using namespace F;
int z=0;p[u]=dep[u]+1;a[u]=ia[u]=b[u]=1;
if(son[u]) f[z=son[u]]=f[u]+1,dfs2(z,u),a[u]=a[z],ia[u]=ia[z],b[u]=b[z],p[u]=p[z]+1,t[u]=t[z],f[u][0]=trs(u,1);
for(auto v:E[u])
if(v!=fa&&v!=son[u]) {
z=v,New(v),dfs2(v,u);
for(int i=0;i<=dep[v]+1;i++) {
if(p[u]==i) S(v,p[u]),S(v,f[u][i]),f[u][p[u]++]=t[u];
S(v,f[u][i]),f[u][i]=trs(u,1ll*val(u,i)*(i?val(v,i-1):1)%mod);
}
if(dep[u]>dep[v]+1) {
int w=val(v,dep[v]);
if(w) {
S(v,a[u]),S(v,ia[u]),S(v,b[u]),a[u]=1ll*a[u]*w%mod,b[u]=1ll*b[u]*w%mod,ia[u]=1ll*ia[u]*inv[v]%mod;
for(int i=0;i<=dep[v]+1;i++) S(v,f[u][i]),f[u][i]=trs(u,1ll*val(u,i)*inv[v]%mod);
}
else S(v,p[u]),S(v,t[u]),p[u]=dep[v]+1,t[u]=trs(u,0);
}
}
if(z) S(z,b[u]);inc(b[u],1);
}
void dfs3(int u,int fa) {
using namespace G;
reverse(E[u].begin(),E[u].end());
int sx=0,prd=1;h[0]=1;
inc(ans,qpow(1ll*sub(F::val(u,L),1)*val(u,L)%mod,k));
if(fa) dec(ans,qpow(1ll*sub(F::val(u,L-1),1)*sub(val(u,L),1)%mod,k));
for(auto v:E[u])
if(v!=fa&&v!=son[u]) {
undo(v);p[v]=L+1;a[v]=ia[v]=1;
for(int i=max(0,L-dep[v]);i<=L;i++) g[v][i]=add(1ll*(i?val(u,i-1):0)*(i>1?1ll*F::val(u,i-1)*(i-2<=sx?h[i-2]:prd)%mod:1)%mod,1);
for(int i=0;i<=dep[v];i++) h[i]=1ll*(i>sx?prd:h[i])*F::val(v,i)%mod;
sx=dep[v];prd=1ll*prd*F::val(v,dep[v])%mod;
}
int z;
if(son[u]) {
g[z=son[u]]=g[u]-1;a[z]=a[u];b[z]=b[u];ia[z]=ia[u];t[z]=t[u];p[z]=p[u]+1,L-dep[z]<=0?g[z][0]=trs(z,0):0;
for(auto v:E[u])
if(v!=fa&&v!=son[u]) {
p[z]=max(p[z],L-dep[z]);
for(int i=max(0,L-dep[z]);i<=min(L,dep[v]+2);i++) {
if(p[z]==i) g[z][p[z]++]=t[z];
g[z][i]=trs(z,1ll*val(z,i)*(i>1?F::val(v,i-2):1)%mod);
}
if(L>dep[v]+2) {
int w=F::val(v,dep[v]);
if(w) {
a[z]=1ll*a[z]*w%mod,b[z]=1ll*b[z]*w%mod,ia[z]=1ll*ia[z]*inv[v]%mod;
for(int i=max(0,L-dep[z]);i<=min(L,dep[v]+2);i++) g[z][i]=trs(z,1ll*val(z,i)*inv[v]%mod);
}
else p[z]=dep[v]+2,t[z]=trs(z,0);
}
}
inc(b[z],1);dfs3(z,u);
}
for(auto v:E[u]) if(v!=fa&&v!=son[u]) dfs3(v,u);
}
int main() {
#ifndef ONLINE_JUDGE
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
#endif
n=gi(),L=gi(),k=gi();
if(L==0) return printf("%d\n",n),0;
for(int i=1,u,v;i<n;i++) u=gi(),v=gi(),E[u].push_back(v),E[v].push_back(u);
dfs1(1,0);
s[0]=1;for(int i=1;i<=n;i++) if(inv[i]) ++m,s[m]=1ll*s[m-1]*inv[i]%mod;
s[m]=qpow(s[m],mod-2);
for(int i=n,t;i;i--) if(inv[i]) --m,t=inv[i],inv[i]=1ll*s[m+1]*s[m]%mod,s[m]=1ll*s[m+1]*t%mod;
New(1),dfs2(1,0);G::a[1]=G::ia[1]=G::b[1]=1;G::p[1]=L+1;dfs3(1,0);
printf("%d\n",ans);
return 0;
}