【XSY2843】「地底薔薇」 NTT什么的 擴展拉格朗日反演


題目大意

  給定集合\(S\),請你求出\(n\)個點的“所有極大點雙連通分量的大小都在\(S\)內”的不同簡單無向連通圖的個數對\(998244353\)取模的結果。

  \(n\leq {10}^5,(m=\sum_{x\in S})\leq {10}^5\)

題解

  首先你要會求\(n\)個點帶標號有根簡單無向圖的個數。bzoj3456就是求這個東西。

  記\(H(x)\)為帶標號有根簡單無向圖個數的EGF。

  記\(b_i\)\(i+1\)個點的帶標號點雙個數,\(B(x)=\sum_{i\geq 0}\frac{b_i}{i!}x^i\)

  考慮一個有根連通圖是長怎樣的。

  先把根刪掉,然后整個圖會分為很多個連通塊。每個連通塊內部都有一些點和根在同一個點雙內,把點雙里面的所有邊刪掉之后整個點雙會分成很多個以這些點為根的連通圖。我們枚舉單個點雙還剩下多少個點,則單個連通塊的答案是

\[\sum_{i\geq 1}\frac{b_i}{i!}H^i(x)=B(H(x)) \]

  把所有連通塊合在一起,有

\[H(x)=xe^{B(H(x))} \]

  現在我們知道\(H(x)\),要求\(B(x)\)中某些項的系數。

  記\(H^{-1}(x)\)\(H(x)\)的復合逆。

  然后就有

\[\begin{align} H(x)&=xe^{B(H(x))}\\ x&=\frac{H(x)}{e^{B(H(x))}}\\ H^{-1}(x)&=\frac{x}{e^{B(x)}}\\ B(x)&=\ln \frac{x}{H^{-1}(x)} \end{align} \]

  如果直接用拉格朗日反演求\(H^{-1}(x)\)的系數再求\(B(x)\)的話,要求出全部\(O(n)\)項,要花費\(O(n^2)\)的時間。這太慢了。

  有個東西叫擴展拉格朗日反演:

\[[x^n]G(F^{-1}(x))=\frac{1}{n}[x^{-1}]\frac{dG(x)}{dx}\frac{1}{F^n(x)} \]

  我們要構造\(G(x)\)使得\(G(H^{-1}(x))=B(x)\)

\[\begin{align} G(H^{-1}(x))&=B(x)\\ G(x)&=B(H^{-1}(x))=\ln \frac{H(x)}{x} \end{align} \]

  所以我們就可以求出\(G(x)\),然后在\(O(n\log n)\)內求出\([x^n]B(x)=[x^n]G(F^{-1}(x))\)了。

  因為\(m\)\(\leq {10}^5\)的,總復雜度就是\(O(m\log m)\)

  記

\[[x^n]A(x)=[x^n]\sum_{(i-1)\in S}\frac{b_i}{i!}x^i \]

\(C(x)\)為滿足題目要求的帶標號有根簡單無向圖個數的EGF,那么滿足

\[C(x)=xe^{A(C(x))} \]

  再做一次拉格朗日反演就可以得到\([x^n]C(x)\)了。

\[[x^n]C(x)=\frac{1}{n}[x^{-1}]\frac{1}{{(\frac{x}{e^{A(x)}})}^n}=\frac{1}{n}[x^{n-1}]e^{nA(x)} \]

  最終答案為\((n-1)![x^n]C(x)\)。因為我們求的是EGF所以要乘以\(n!\),然后是有根變無根要除以一個\(n\)

  時間復雜度:\(O(n\log n+m\log m)\)

代碼

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int p=998244353;
const int N=300000;
const int W=262144;
ll fp(ll a,ll b){ll s=1;for(;b;b>>=1,a=a*a%p)if(b&1)s=s*a%p;return s;}
int iv[N];
int ifac[N];
int fac[N];
int w[W];
void ntt(int *a,int n,int t)
{
	static int rev[N];
	for(int i=1;i<n;i++)
	{
		rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
		if(rev[i]>i)
			swap(a[i],a[rev[i]]);
	}
	for(int i=2;i<=n;i<<=1)
		for(int j=0;j<n;j+=i)
			for(int k=0;k<i/2;k++)
			{
				int u=a[j+k];
				int v=(ll)a[j+k+i/2]*w[W/i*k]%p;
				a[j+k]=(u+v)%p;
				a[j+k+i/2]=(u-v)%p;
			}
	if(t==-1)
	{
		reverse(a+1,a+n);
		ll invn=fp(n,p-2);
		for(int i=0;i<n;i++)
			a[i]=a[i]*invn%p;
	}
}
void copy(int *a,int *b,int l,int r){memcpy(a+l,b+l,sizeof(a[0])*(r-l));}
void clear(int *a,int l,int r){memset(a+l,0,sizeof(a[0])*(r-l));}
void mul(int *a,int *b,int *c,int n,int m,int l=-1)
{
	static int a1[N],a2[N];
	if(l==-1)
		l=n+m;
	n=min(n,l);
	m=min(m,l);
	int k=1;
	while(k<=n+m)
		k<<=1;
	copy(a1,a,0,n+1);
	clear(a1,n+1,k);
	copy(a2,b,0,m+1);
	clear(a2,m+1,k);
	ntt(a1,k,1);
	ntt(a2,k,1);
	for(int i=0;i<k;i++)
		a1[i]=(ll)a1[i]*a2[i]%p;
	ntt(a1,k,-1);
	copy(c,a1,0,l+1);
}
void mul2(int *a,int *b,int *c,int n)
{
	mul(a,b,c,n-1,n-1,n-1);
}
void inv(int *a,int *b,int n)
{
	if(n==1)
	{
		b[0]=fp(a[0],p-2);
		return;
	}
	inv(a,b,n>>1);
	static int a1[N],a2[N];
	copy(a1,a,0,n);
	clear(a1,n,n<<1);
	copy(a2,b,0,n>>1);
	clear(a2,n>>1,n<<1);
	ntt(a1,n<<1,1);
	ntt(a2,n<<1,1);
	for(int i=0;i<n<<1;i++)
		a1[i]=a2[i]*(2-(ll)a1[i]*a2[i]%p)%p;
	ntt(a1,n<<1,-1);
	copy(b,a1,0,n);
}
void ln(int *a,int *b,int n)
{
	static int a1[N],a2[N];
	for(int i=1;i<n;i++)
		a1[i-1]=(ll)a[i]*i%p;
	a1[n-1]=0;
	inv(a,a2,n);
	mul2(a1,a2,a1,n);
	for(int i=1;i<n;i++)
		b[i]=(ll)a1[i-1]*iv[i]%p;
	b[0]=0;
}
void exp(int *a,int *b,int n)
{
	if(n==1)
	{
		b[0]=1;
		return;
	}
	exp(a,b,n>>1);
	static int a1[N],a2[N];
	clear(b,n>>1,n);
	ln(b,a1,n);
	for(int i=0;i<n>>1;i++)
		a1[i]=(a[i+(n>>1)]-a1[i+(n>>1)])%p;
	mul2(a1,b,a2,n>>1);
	for(int i=0;i<n>>1;i++)
		b[i+(n>>1)]=a2[i];
}
void pow(int *a,int *b,int n,int m)
{
	static int a1[N],a2[N],a3[N];
	int k=1;
	while(k<=n)
		k<<=1;
	copy(a1,a,0,n+1);
	clear(a1,n+1,k);
	ln(a1,a2,k);
	for(int i=0;i<k;i++)
		a2[i]=(ll)a2[i]*m%p;
	exp(a2,a3,k);
	copy(b,a3,0,n+1);
}
int a[N],b[N],g[N],h[N],f[N];
int n,m;
void geth()
{
	int k=1;
	while(k<=n+2)
		k<<=1;
	for(int i=0;i<=n+2;i++)
		f[i]=fp(2,ll(i-1)*i/2)*ifac[i]%p;
	ln(f,h,k);
	for(int i=0;i<=n+1;i++)
		h[i]=(ll)h[i]*i%p;
}
void getg()
{
	static int a1[N];
	for(int i=0;i<=n;i++)
		a1[i]=h[i+1];
	int k=1;
	while(k<=n)
		k<<=1;
	ln(a1,g,k);
	for(int i=1;i<=n;i++)
		g[i-1]=(ll)g[i]*i%p;
}
int getb(int x)
{
	int k=1;
	while(k<x)
		k<<=1;
	static int a1[N],a2[N];;
	for(int i=0;i<x;i++)
		a1[i]=(ll)h[i]*(-x)%p;
	exp(a1,a2,k);
	mul(g,a2,a1,x-1,x-1,x-1);
	return (ll)a1[x-1]*iv[x]%p;
}
int geta()
{
	static int a1[N];
	int k=1;
	while(k<n)
		k<<=1;
	for(int i=0;i<k;i++)
		b[i]=(ll)b[i]*n%p;
	exp(b,a1,k);
	return (ll)a1[n-1]*iv[n]%p;
}
int c[N];
int main()
{
#ifndef ONLINE_JUDGE
	freopen("d.in","r",stdin);
	freopen("d.out","w",stdout);
#endif
	fac[0]=fac[1]=ifac[0]=ifac[1]=iv[1]=1;
	for(int i=2;i<=W;i++)
	{
		iv[i]=(ll)-p/i*iv[p%i]%p;
		ifac[i]=(ll)ifac[i-1]*iv[i]%p;
		fac[i]=(ll)fac[i-1]*i%p;
	}
	ll w1=fp(3,(p-1)/W);
	w[0]=1;
	for(int i=1;i<W;i++)
		w[i]=w[i-1]*w1%p;
	scanf("%d%d",&n,&m);
	geth();
	getg();
	int x;
	int k=1;
	for(int i=1;i<=m;i++)
	{
		scanf("%d",&c[i]);
		c[i]--;
		while(k<c[i])
			k<<=1;
	}
	for(int i=0;i<k;i++)
		h[i]=h[i+1];
	ln(h,h,k);
	for(int i=1;i<=m;i++)
		b[c[i]]=getb(c[i]);
	ll ans=geta();
	ans=ans*fac[n-1]%p;
	ans=(ans+p)%p;
	printf("%lld\n",ans);
	return 0;
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM