再探快速傅里葉變換(FFT)學習筆記(其三)(循環卷積的Bluestein算法+分治FFT+FFT的優化+任意模數NTT)


再探快速傅里葉變換(FFT)學習筆記(其三)(循環卷積的Bluestein算法+分治FFT+FFT的優化+任意模數NTT)

8718367adab44aedf83ea643bf1c8701a18bfb21.jpg

寫在前面

為了不使篇幅過長,預計將把基於論文的學習筆記分為三部分:

  1. DFT,IDFT,FFT的定義,實現與證明:快速傅里葉變換(FFT)學習筆記(其一)
  2. NTT的實現與證明:快速傅里葉變換(FFT)學習筆記(其二)
  3. 任意模數NTT與FFT的優化技巧

一些約定

  1. \([p(x)]=\begin{cases}1,p(x)為真 \\ 0,p(x)為假 \end{cases}\)
  2. 本文中序列的下標從0開始
  3. \(s\)是一個序列,\(|s|\)表示\(s\)的長度
  4. 若大寫字母如\(F(x)\)表示一個多項式,那么對應的小寫字母如\(f\)表示多項式的每一項系數,即\(F(x)=\sum_{i=0}^{n-1} f_ix^i\)

循環卷積

DFT卷積的本質

考慮在(其一)中提到的卷積的定義式。

\[c_{r}=\sum_{p, q}[(p+q) \bmod n=r] a_{p} b_{q} \tag{1.1} \]

我們一般做FFT時忽略了式子中的\(\bmod\),其實它是在\(\bmod 2^q\)的意義下的循環卷積,只是因為\(|a|,|b|,|c|<2^q\),所以取不取模都沒什么影響。

如果序列長度\(n\)是2的整數次冪,那么直接做就可以了。

如果序列長度\(n\)不是2的整數次冪考慮暴力的做法:先做一次普通FFT,再把\(c_{k+n}\)加到\(c_k\)上。但是這樣在做多次FFT時就必須一次一次做,比如多項式快速冪。下面給出了一種在\(O(n \log n)\)的時間內實現任意長度循環卷積的算法:Bluestein’s Algorithm

Bluestein’s Algorithm

注:原論文的推導可能有誤

考慮DFT的式子

\[\begin{aligned} a'_i&=\sum_{j=0}^{n-1} a_j \omega_n^{ij} \\&=\sum_{j=0}^{n-1} a_j \omega_n^{\frac{-(i-j)^2+i^2+j^2}{2}} \\&= \omega_n^{\frac{i^2}{2}} \sum_{j=0}^{n-1}a_j \omega_n^{\frac{j^2}{2}} \omega_n^{-\frac{(i-j)^2}{2}}\end{aligned} \]

不妨設

\(x_j=a_j \omega_n^{\frac{j^2}{2}}=a_j(\cos\frac{j^2\pi}{n}+ \text{i}\sin{\frac{j^2\pi}{n}})\)

\(y_j=\omega_n^{-\frac{j^2}{2}}= \cos \frac{\pi j^2}{n}-\text{i}\sin \frac{\pi j^2}{n}\)

那么\(a_i'=\omega_n^{\frac{j^2}{2}}\sum_{j=0}^{n-1} x_j y_{i-j}\)

這已經很類似卷積的形式了,但是注意到\(j\)的上界是\(n-1\)而不是\(i\),\(j-i\)可能為負數。那么我們把\(y\)數組的長度擴大到\(2n\),定義:

\(y_j=\omega_n^{-\frac{(j-n)^2}{2}}= \cos \frac{\pi (j-n)^2}{n}-\text{i}\sin \frac{\pi (j-n)^2}{n}\).

這樣\(j<n\)的時候就對應了\(j-i\)為負數的情形,\(j\geq n\)就對應了\(j-i\)為正的情形。然后對\(x\)\(y\)用一般的FFT,最后的答案存儲在\(i+n\)的位置上,也就是說真正的\(a'_i\)實際上對應了乘積結果的\((x \cdot y)_{i+n}\)

這樣,我們就只做了3次FFT就求出了任意長度循環DFT。逆變換同理,只是換成共軛復數。注意到在上述的推導中我們沒有用到單位根\(\omega\)的任何性質,因此這里的\(\omega\)可以換成任意復數\(z\),這樣的變換稱為Chirp Z-Transform,CZT.可見,CZT實際上是DFT的廣義形式。

代碼實現:

//com是手寫復數類,省略
void fft(com *x,int *rev,int n,int type){
	//為節約篇幅,fft部分省略,x為系數序列,rev為反轉數組,n為長度,type=1表示DFT,type=-1表示IDFT
} 
void bluestein(com *a,int n,int type){ 
    //a為系數序列,n為長度,type=1表示DFT,type=-1表示IDFT
	static com x[maxn*4+5],y[maxn*4+5];
	static int rev[maxn*4+5];
	memset(x,0,sizeof(x));
	memset(y,0,sizeof(y));
    //FFT前的預處理
	int N=1,L=0;
	while(N<n*4){
		L++;
		N*=2;
	}
	for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
    //x[i],y[i]的定義見上式
	for(int i=0;i<n;i++) x[i]=com(cos(pi*i*i/n),type*sin(pi*i*i/n))*a[i];
	for(int i=0;i<n*2;i++) y[i]=com(cos(pi*(i-n)*(i-n)/n),-type*sin(pi*(i-n)*(i-n)/n));
	fft(x,rev,N,1);
	fft(y,rev,N,1);
	for(int i=0;i<N;i++) x[i]*=y[i];
	fft(x,rev,N,-1);
	for(int i=0;i<n;i++){
		a[i]=x[i+n]*com(cos(pi*i*i/n),type*sin(pi*i*i/n));//記得乘上常數
		if(type==-1) a[i]/=n;//一定記得除以n,因為做一次Bluestein相當於一次FFT,IFFT最后要除n,這里也要除n 
	} 
}

例題

[POJ 2821]TN's Kindom III(任意長度循環卷積的Bluestein算法)

分治FFT

一般我們用FFT的時候,序列的所有元素都已知。但是,如果序列本身是根據卷積定義的,就無法直接套FFT

舉一個最簡單的例子\(f_i =\sum_{j=1}^i f_{i-j}g_j\).其中\(g\)給定,求\(f\). 由於我們卷積的時后后面的數基於前面的數,無法快速計算,時間復雜度退化到\(O(n^2)\). (雖然這個式子可以用(其四)中將會提到的多項式求逆解決,但是分治FFT更通用,可以處理很復雜的式子)

考慮分治: 設當前分治區間為\([l,r]\),假設我們求出了\([l,mid]\)的答案,那么可以求出這些點對\([mid+1,r]\)的影響。那么右半邊的點\(x \in [mid+1,r]\)得到的貢獻是\(\Delta_x=\sum_{i=l}^{mid} f_i g_{x-i}\).只需要把下標偏移一下(如\([l,mid]\)偏移成\([0,mid-l]\),就是一個卷積的形式,可以運用FFT或NTT計算,計算完之后,把答案累加到數組上.

偽代碼如下:

poly f,g;//上述的f,g
procedure calc(L,mid,R){
	for i in [L,mid] : a[i-L] <- f[i]//下標偏移
	for i in [1,R-L] : b[i-1] <- g[i]
	a <- mul(a,b);//fft或ntt做多項式乘法
	for i in [mid+1,R] f[i] <- f[i]+a[i-l-1]//累加貢獻
}
procedure solve(l,mid){
	if(l==r) return;
	mid <- (l+r)/2
	solve(l,mid);
	calc(l,mid,r);
	solve(mid+1,r)
}

時間復雜度分析:

\(T(n)=2T(\frac{n}{2})+n \log_2n\), 總復雜度\(\Theta(n \log^2n)\)

下面是基於NTT的模板代碼(Luogu 4721)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath> 
#define maxn 300000
#define G 3
#define invG 332748118
#define inv2 499122177
#define mod 998244353
using namespace std;
typedef long long ll;
inline ll fast_pow(ll x,ll k){
	ll ans=1;
	while(k){
		if(k&1) ans=ans*x%mod;
		x=x*x%mod;
		k>>=1;
	}
	return ans;
}
inline ll inv(ll x){
	return fast_pow(x,mod-2); 
}

void NTT(ll *x,int n,int type){
	static int rev[maxn+5];
	int tn=1;
	int k=0;
	while(tn<n){
		tn*=2;
		k++;
	}
	for(int i=0;i<tn;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
	for(int i=0;i<n;i++){
		if(i<rev[i]) swap(x[i],x[rev[i]]);
	} 
	for(int len=1;len<n;len*=2){
		int sz=len*2;
		ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz);
		for(int l=0;l<n;l+=sz){
			int r=l+len-1;
			ll gnk=1;
			for(int i=l;i<=r;i++){
				ll tmp=x[i+len];
				x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod;
				x[i]=(x[i]+gnk*tmp%mod)%mod;
				gnk=gnk*gn1%mod;
			} 
		} 
	}
	if(type==-1){
		int invsz=inv(n);
		for(int i=0;i<n;i++) x[i]=x[i]*invsz%mod; 
	}
}
void mul(ll *a,ll *b,ll *ans,int sz){
	NTT(a,sz,1);
	NTT(b,sz,1);
	for(int i=0;i<sz;i++) ans[i]=a[i]*b[i]%mod;
	NTT(ans,sz,-1);
} 


void cdq_divide(ll *f,ll *g,int l,int r){
	static ll tmpa[maxn+5],tmpb[maxn+5];
	if(l==r) return; 
	int mid=(l+r)>>1;
	cdq_divide(f,g,l,mid);
	int tn=1,k=0;
	while(tn<r-l){
		k++;
		tn*=2; 
	}
	for(int i=0;i<tn;i++) tmpa[i]=tmpb[i]=0; 
	for(int i=l;i<=mid;i++) tmpa[i-l]=f[i];
	for(int i=1;i<=r-l;i++) tmpb[i-1]=g[i];
	mul(tmpa,tmpb,tmpa,tn);
	for(int i=mid+1;i<=r;i++) f[i]=(f[i]+tmpa[i-l-1])%mod;
	cdq_divide(f,g,mid+1,r);
}

int n;
ll f[maxn+5],g[maxn+5];
int main(){
	scanf("%d",&n);
	for(int i=1;i<n;i++) scanf("%lld",&g[i]); 
	f[0]=1;
	cdq_divide(f,g,0,n-1);
	for(int i=0;i<n;i++) printf("%lld ",f[i]); 
} 

容易發現,許多dp方程都有分治FFT的形式。對於此類dp方程,我們可以用分治FFT將轉移復雜度由\(O(n^2)\)降到\(O(n \log^2 n)\)

例題

[Codeforces 553E]Kyoya and Train(期望DP+Floyd+分治FFT)

FFT的弱常數優化

下面介紹一些優化FFT的常數的技巧。雖然這些技巧都只是對FFT的一些小優化,但是在某些題目中優化效果極其明顯。

復雜算式中減少FFT次數

如果我們要計算一個復雜的多項式,如\(A(x)=B(x)C(x)+D(x)E(x)\)

最簡單的方法是分別計算\(B(x)C(x)\)\(D(x)E(x)\),這樣需要做6次FFT. 但是如果先對\(B,C,D,E\)做DFT,然后直接用點值表達式計算\(a_i=b_ic_i+d_ie_i\),再把\(a\)IDFT回去。這樣只需要做5次FFT,且多項式越復雜,這樣的常數就越優秀。

例題

[BZOJ 3771] Triple(FFT+容斥原理+生成函數)

利用循環卷積

考慮對於兩個長度為\(n\)的序列\(a,b\),計算它們的卷積\(c\)的第\(0.5n\)項到第\(1.5n\)項。傳統的方法是補0擴充到\(2n\)的序列。但是因為FFT求得實際上是我們已經提到過的循環卷積,所以如果只補0到\(1.5n\)(上取整),對第\(0.5n\)項到第\(1.5n\)項無影響

在基於牛頓迭代的算法中,能起到較明顯的優化作用。會在(其四)中詳細介紹這些算法。

小范圍暴力

由於FFT的常數較大。在數據范圍較小的時候甚至不如\(O(n^2)\)的暴力卷積的優秀。因此在做多次FFT和分治FFT的時候,如果當前的序列長度較小,可以采用暴力算法。

例題

[BZOJ 3509] [CodeChef] COUNTARI (FFT+分塊)

快速冪乘法次數的優化

這個東西實際上比較雞肋。因為多項式快速冪可以通過多項式\(\ln\)\(\exp\)優化到\(O(n \log n)\).但是為了應對考場上時間不夠的情況,我們來考慮如何通過簡單的實現來減少\(O(n \log^2n)\)的倍增快速冪的復雜度。

倍增法的思路是根據前面算過的乘積快速算出當前的乘積,如\(1 \to 2 \to 4 \to 8\).最壞情況下需要\(2 \log_2n+C\)次乘法。但這並不是下界。我們定義additional chain為一條鏈,最開始是1,后一個數減前一個數的差是鏈上這個是前面的某一個數。例如\(1 \to 2 \to 4 \to 6\).\(6-4=2\)在前面出現過,\(4-2=2\)在前面出現過。那么根據這條additional chain計算6次冪的時候,可以從1次冪出發,用1次冪乘1次冪得到2次冪,再乘2次冪得到4次冪,再乘2次冪得到6次冪。

很可惜,對於數\(k\)求出得到\(k\)的最短additional chain是NP-hard的。但是有很好的近似算法。近似算法基於BFS。每次我們對於隊頭的數\(x\),枚舉它對應的additional chain中的數\(y\),如果\(x+y\)還沒有訪問過那么將其入隊,並將\(x\)對應的鏈后面接上\(x+y\). 這個預處理是\(O(k)\)的,且對快速冪的常數優化很顯著。

如果\(k\)很大,比如\(10^{10000}\),可以采用十進制快速冪。但是用Method of Four Russians(俗稱四毛子算法),可以將乘法次數減少到\(\log_2n+O(\frac{\log n}{\log \log n})\).具體方法見2017年國家集訓隊論文《非常規大小分塊算法初探》

FFT的強常數優化

FFT的強常數優化一般是通過減少FFT次數來實現的
在這一節中,我們記\(DFT(A(x))\)表示多項式\(A(x)\)(或序列)做DFT之后的結果,\(IDFT(A(x))\)同理

我們現在考慮最常見的一個模型:給出兩個長度為\(n+1\)\(m+1\)的多項式\(A(x),B(x)\),我們要計算他們的線性卷積。假設長度已經補齊為第一個大於\(n+m+1\)的2的整數冪\(L\)

顯然直接搞需要3次長度為\(L\)的FFT。毒瘤的Vladimir Smykalov在cf上最先給出了這個問題的優化算法。

DFT的合並

DFT的合並是指,對於兩個序列\(a\),\(b\),我們只通過一次FFT就求出\(DFT(a),DFT(b)\)

不妨設:

\[P(x)=A(x)+\text{i}B(x) \tag{4.1} \]

\[Q(x)=A(x)-\text{i}B(x) \tag{4.2} \]

接下來我們開始推導公式。注意為了簡潔,我們記\(X=\frac{2 \pi jk}{2L}\),\(\text{conj}(z)\)表示\(z\)的共軛復數

\[\begin{aligned} DFT(p_k) &=A\left(\omega_{2 L}^{k}\right)+i B\left(\omega_{2 L}^{k}\right) \\ &=\sum_{j=0}^{2 L-1} a_{j} \omega_{2 L}^{j k}+i b_{j} \omega_{2 L}^{j k} \\ &=\sum_{j=0}^{2 L-1}\left(a_{j}+i b_{j}\right)(\cos X+i \sin X) \end{aligned}\]

\[\begin{aligned} DFT(q_k) &=A\left(\omega_{2 L}^{k}\right)-i B\left(\omega_{2 L}^{k}\right) \\ &=\sum_{j=0}^{2 L-1} a_{j} \omega_{2 L}^{j k}-i b_{j} \omega_{2 L}^{j k} \\ &=\sum_{j=0}^{2 L-1}\left(a_{j}-i b_{j}\right)(\cos X+i \sin X) \\ &=\sum_{j=0}^{2 L-1}\left(a_{j} \cos X+b_{j} \sin X+i \sin X-b_{j} \cos X\right) \\&=\operatorname{conj}\left(\sum_{j=0}^{2 L-1}\left(a_{j} \cos X+b_{j} \sin X\right)-i\left(a_{j} \sin X-b_{j} \cos X\right)\right)\\ &=\operatorname{conj}\left(\sum_{j=0}^{2 L-1}\left(a_{j} \cos (-X)-b_{j} \sin (-X)\right)+i\left(a_{j} \sin (-X)+b_{j} \cos (-X)\right)\right)\\ &=\operatorname{conj}\left(\sum_{j=0}^{2 L-1}\left(a_{j}+i b_{j}\right)(\cos (-X)+i \sin (-X))\right)\\ &=\operatorname{conj}\left(\sum_{j=0}^{2 L-1}\left(a_{j}+i b_{j}\right) \omega_{2 i}^{-j k}\right)\\ &=\operatorname{conj}\left(\sum_{j=0}^{2 L-1}\left(a_{j}+i b_{j}\right) \omega_{2 L}^{(2 L-k) j}\right)\\ &=\operatorname{conj}\left(p'[2 L-k]\right) \end{aligned}\]

也就是說,只要一次DFT算出\(DFT(p)\),就可以把序列反轉再取共軛復數得到\(DFT(q)\).

由於DFT是線性變換,

\[DFT(a_k)=\frac{DFT(p_k)+DFT(q_k)}{2}=\frac{DFT(p_k)+\text{conj}(DFT(p_j))}{2} \]

其中\(j\)\(k\)翻轉后的數,即\(j=\begin{cases}0,k=0 \\ L-k ,k>0 \end{cases}\)

又由\((4.1),(4.2)\)

\[DFT(a_k)=\frac{DFT(p_k)+DFT(q_k)}{2} \tag{4.3} \]

\[DFT(b_k)=-\text{i}\frac{DFT(p_k)-DFT(q_k)}{2} \tag{4.4} \]

\[DFT(a_k)DFT(b_k)=\text{i}\frac{{DFT(p_k)}^2-{DFT(q_k)}^2}{4} \tag{4.5} \]

這樣我們就可以從\(q'\)推出\(a',b'\),也就是說一次DFT就能得到\(a'\)\(b'\)了.

我們一共做了2次長度為\(L\)的FFT.

代碼(UOJ#34):

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std; 
typedef long long ll;
struct com{
	double real;
	double imag;
	com(){
		
	} 
	com(double _real,double _imag){
		real=_real;
		imag=_imag;
	}
	com(double x){
		real=x;
		imag=0;
	}
	void operator = (const com x){
		this->real=x.real;
		this->imag=x.imag;
	}
	void operator = (const double x){
		this->real=x;
		this->imag=0;
	}
	friend com operator + (com p,com q){
		return com(p.real+q.real,p.imag+q.imag);
	}
	friend com operator + (com p,double q){
		return com(p.real+q,p.imag);
	}
	void operator += (com q){
		*this=*this+q;
	}
	void operator += (double q){
		*this=*this+q;
	}
	friend com operator - (com p,com q){
		return com(p.real-q.real,p.imag-q.imag);
	}
	friend com operator - (com p,double q){
		return com(p.real-q,p.imag);
	}
	void operator -= (com q){
		*this=*this-q;
	}
	void operator -= (double q){
		*this=*this-q;
	}
	friend com operator * (com p,com q){
		return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
	}
	friend com operator * (com p,double q){
		return com(p.real*q,p.imag*q);
	} 
	void operator *= (com q){
		*this=(*this)*q;
	}
	void operator *= (double q){
		*this=(*this)*q;
	}
	friend com operator / (com p,double q){
		return com(p.real/q,p.imag/q);
	} 
	void operator /= (double q){
		*this=(*this)/q;
	} 
	com conj(){
		return com(real,-imag);
	}
	void print(){
		printf("%lf + %lf i ",real,imag);
	}
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){
	for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
	for(int len=1;len<n;len*=2){
		int sz=len*2;
		for(int l=0;l<n;l+=sz){
			int r=l+len-1;
			for(int i=l;i<=r;i++){
				com tmp=x[i+len];
				x[i+len]=x[i]-tmp*w[n/sz*(i-l)];//w(sz,k)=w(n,n/sz*k)
				x[i]=x[i]+tmp*w[n/sz*(i-l)];
			}
		}
	}
}
void mul(ll *a,ll *b,ll *c,int n){
	static com p[maxn+5],r[maxn+5];
	for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));//預處理單位根 
	for(int i=0;i<n;i++) p[i]=com(a[i],b[i]);//p[i]=a[i]+ib[i]
	fft(p,n);
	for(int i=0;i<n;i++){
		int j=(i>0?(n-i):0);//0的位置需要特判一下
		com q=p[j];
		r[j]=(p[i]*p[i]-q.conj()*q.conj())*com(0,-0.25);//按照上面的式子
	}	
	fft(r,n);//這里是用了第一篇中提到的反轉技巧
	for(int i=0;i<n;i++) c[i]=r[i].real/n+0.5;
}

int n,m; 
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
	scanf("%d %d",&n,&m);
	for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
	for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
	int N=1,L=0;
	while(N<n+m+1){
		L++;
		N*=2;
	} 
	for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
	mul(a,b,c,N);
	for(int i=0;i<n+m+1;i++) printf("%lld\n",c[i]);
}

IDFT的合並

IDFT的合並是指,對於兩個序列\(a\),\(b\),我們只通過一次FFT就求出\(IDFT(a),IDFT(b)\)

IDFT的合並非常簡單。
\(r(x)=a(x)+\text{i}b(x)\)
由於IDFT是線性變換
\(IDFT(r(x))=IDFT(a(x))+\text{i}IDFT(b(x))\)
又因為\(a(x)\)\(b(x)\)都是實數序列,那么\(IDFT(r(x))\)的實部就是\(IDFT(a(x))\),虛部就是\(IDFT(b(x))\)

形如\((A+B)(C+D)\)的卷積的優化

在這一節中我們討論\((A(x)+B(x))(C(x)+D(x))\)形式的卷積的優化.

一般的做法是對\(A,B,C,D\)都做一次DFT,然后按照這個式子直接計算,最后再IDFT回來。需要5次FFT.

而根據上面的合並技巧,先把\(A(x),B(x)\)合並DFT,\(C(x),D(x)\)合並DFT得到點值表達式.
由於\((A(x)+B(x))(C(x)+D(x))=A(x)C(x)+A(x)D(x)+B(x)C(x)+B(x)D(x)\)
我們可以直接把點值表達式相乘得到這4個多項式。對於這4個多項式,分成2組合並做IDFT即可。
總共需要4次FFT.

大致代碼如下:

void mul(ll *a,ll *b,ll *c,ll *d,ll *ans,int n){
	static com p[maxn+5],q[maxn+5];
	static com r[maxn+5],s[maxn+5];
	for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
	for(int i=0;i<n;i++){
		p[i]=com(a[i],b[i]);//打包A,B 
		q[i]=com(c[i],d[i]);//打包C,D 
	}
	fft(p,n);
	fft(q,n);
	for(int i=0;i<n;i++){
		int j=(i==0?0:n-i);
		//得到DFT(A),DFT(B),DFT(C),DFT(D) 
		com da=(p[i]+p[j].conj())*0.5;
		com db=(p[i]-p[j].conj())*com(0,-0.5);
		com dc=(q[i]+q[j].conj())*0.5;
		com dd=(q[i]-q[j].conj())*com(0,-0.5);
		r[j]=da*dc+da*dd*com(0,1);//打包AC,AD 
		s[j]=db*dc+db*dd*com(0,1); //打包BC,BD 
	}
	fft(r,n);
	fft(s,n);
	for(int i=0;i<n;i++){
		ll ac,ad,bc,bd; 
		ac=(ll)(r[i].real/n+0.5);
        ad=(ll)(r[i].imag/n+0.5);
        bc=(ll)(s[i].real/n+0.5);
        bd=(ll)(s[i].imag/n+0.5);
        ans[i]=ac+ad+bc+bd;
	}
}

卷積的終極優化

上述優化中我們只用到了DFT的思想。現在我們利用FFT的思想繼續優化

同樣拆分奇偶項,\(A(x)=A_0(x^2)+xA_1(x^2)\)

\[\begin{aligned} A(x)B(x)&=(A_0(x^2)+xA_1(x^2))(B_0(x^2)+xB_1(x^2))\\ &=A_0(x^2)B_0(x^2)+x(A_0(x^2)B_1(x^2)+A_1(x^2)B_0(x^2))+x^2A_1(x^2)B_1(x^2) \end{aligned} \tag{4.6}\]

我們只需要知道上式中\(x^0,x^1,x^2\)的系數
發現\(A_0(x^2)B_1(x^2)+A_1(x^2)B_0(x^2)\)是奇數項的系數,\(A_0(x^2)B_0(x^2)\)\(A_1(x^2)B_1(x^2)\)是偶數項的系數,而偶數項的兩個東西都可以看成一個關於\(x^2\)的多項式。

我們先優化DFT的過程,觀察\((4.6)\)式的乘積形式\((A_0(x^2)+xA_1(x^2))(B_0(x^2)+xB_1(x^2))\).

我們發現,這個形式和上一節的\((A+B)(C+D)\)很像,可以類似地優化。
\(p_k={a_0}_k+\text{i}{a_1}_k,q_k={b_0}_k+\text{i}{b_1}_k\)

然后合並IDFT,再設兩個輔助多項式

\[G(x)=DFT(A_0(x))\cdot DFT(B_0(x))+\omega_L^k DFT(A_1(x)) DFT(B_1(x)) \]

(注意我們把\(x^2\)換元成\(x\),做DFT的時候要乘上單位根)

\[F(x)=DFT(A_0(x))\cdot DFT(B_1(x))+ DFT(A_1(x)) DFT(B_0(x)) \]

那么我們只需要計算出\(IDFT(G(x))\)\(IDFT(F(x))\)

\(R(x)=G(x)+\mathrm{i} F(x)\)
那么因為IDFT是線性變換,\(IDFT(R(x))=IDFT(G(x))+\mathrm{i} IDFT(F(x))\)
(IDFT的線性性這里不做證明,容易發現兩個點值表達式相加再IDFT回來,顯然系數也會相加)

顯然這兩個多項式IDFT的結果是實數。故我們只要求出\(IDFT(R(x))\),每一項系數的實部就是偶數項系數\(G(x)\),虛部就是奇數項系數\(F(x)\)

我們再考慮把合並DFT弄進去,即式\((4.3)(4.4)(4.5)\)

接下來我們嘗試用\(DFT(p_k),DFT(q_k)\)來表示\(R(x)=G(x)+\text{i}F(x)\),為了推導簡潔,我們省略\(DFT\)不寫

\[\begin{aligned} g_k&=\frac {p_k+\text{conj}(p_j)}{2}\cdot \frac {q_k+\text{conj}(q_j)}{2}+\omega_L^k\cdot \frac {p_k-\text{conj}(p_j)}{-2i}\cdot \frac {q_k-\text{conj}(q_j)}{-2i}\\ &=\frac 1 4 [(p_k+\text{conj}(p_j))\cdot(q_k+\text{conj}(q_j))-\omega_L^k\cdot(p_k-\text{conj}(p_j))\cdot(q_k-\text{conj}(q_j))]\\ \\ f_k&=\frac {p_k+\text{conj}(p_j)} 2 \cdot \frac{q_k-\text{conj}(q_j)}{-2}i+\frac {q_k+\text{conj}(q_j)} 2 \cdot \frac{p_k-\text{conj}(p_j)}{-2}i\\ &=\frac i{-4}[2\cdot p_k\cdot q_k-2\cdot \text{conj}(p_j)\cdot \text{conj}(q_j)] \end{aligned}\]

那么

\[\begin{aligned} g_k+\text{i} f_k&=\frac 1 4 [(p_k+\text{conj}(p_j))\cdot(q_k+\text{conj}(q_j))-w_L^k\cdot(p_k-\text{conj}(p_j))\cdot(q_k-\text{conj}(q_j))-2\cdot p_k\cdot q_k+2 \text{conj}(p_j\cdot q_j)]\\ &=\frac 1 4 [-(p_k-\text{conj}(p_j))\cdot(q_k-\text{conj}(q_j))+2\cdot (p_k\cdot q_k+\text{conj}(p_j\cdot q_j))\\ &-w_L^k\cdot(p_k-\text{conj}(p_j))\cdot(q_k-\text{conj}(q_j))+2\cdot p_k\cdot q_k-2\cdot \text{conj}(p_j\cdot q_j)]\\ &=q_k\cdot p_k-\frac 1 4[(1+w_L^k)\cdot (p_k-\text{conj}(p_j))\cdot(q_k-\text{conj}(q_j))]\\ \end{aligned}\]

和上一節的\((A+B)(C+D)\)不同,我們只用了3次長度為\(L/2\)的FFT,就求出了答案,這是由於FFT本身的性質。因為長度縮減了一半,我們不妨稱它為\(1.5\)次FFT.

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std; 
typedef long long ll;
struct com{
	double real;
	double imag;
	com(){
		
	} 
	com(double _real,double _imag){
		real=_real;
		imag=_imag;
	}
	com(double x){
		real=x;
		imag=0;
	}
	void operator = (const com x){
		this->real=x.real;
		this->imag=x.imag;
	}
	void operator = (const double x){
		this->real=x;
		this->imag=0;
	}
	friend com operator + (com p,com q){
		return com(p.real+q.real,p.imag+q.imag);
	}
	friend com operator + (com p,double q){
		return com(p.real+q,p.imag);
	}
	void operator += (com q){
		*this=*this+q;
	}
	void operator += (double q){
		*this=*this+q;
	}
	friend com operator - (com p,com q){
		return com(p.real-q.real,p.imag-q.imag);
	}
	friend com operator - (com p,double q){
		return com(p.real-q,p.imag);
	}
	void operator -= (com q){
		*this=*this-q;
	}
	void operator -= (double q){
		*this=*this-q;
	}
	friend com operator * (com p,com q){
		return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
	}
	friend com operator * (com p,double q){
		return com(p.real*q,p.imag*q);
	} 
	void operator *= (com q){
		*this=(*this)*q;
	}
	void operator *= (double q){
		*this=(*this)*q;
	}
	friend com operator / (com p,double q){
		return com(p.real/q,p.imag/q);
	} 
	void operator /= (double q){
		*this=(*this)/q;
	} 
	com conj(){
		return com(real,-imag);
	}
	void print(){
		printf("%lf + %lf i ",real,imag);
	}
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){

	for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
	for(int len=1;len<n;len*=2){
		int sz=len*2;
		for(int l=0;l<n;l+=sz){
			int r=l+len-1;
			for(int i=l;i<=r;i++){
				com tmp=x[i+len];
				x[i+len]=x[i]-tmp*w[n/sz*(i-l)];//w(sz,k)=w(n,n/sz*k)
				x[i]=x[i]+tmp*w[n/sz*(i-l)];
			}
		}
	}
}
void mul(ll *a,ll *b,ll *c,int n){
	static com p[maxn+5],q[maxn+5],r[maxn+5];
	for(int i=0;i<n;i++){//合並做DFT
		if(i%2==1){
			p[i/2].imag=a[i];
			q[i/2].imag=b[i]; 
		}else{
			p[i/2].real=a[i];
			q[i/2].real=b[i];
		}
	}
	n/=2;
	for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
	fft(q,n);
	fft(p,n);
	for(int i=0;i<n;i++){
		int j=(i>0?(n-i):0);
		r[j]=p[i]*q[i]-(w[i]+1)*(p[i]-p[j].conj())*(q[i]-q[j].conj())*0.25;
	}	
	fft(r,n);
	for(int i=0;i<n;i++){
		c[i*2]=r[i].real/n+0.5;
		c[i*2+1]=r[i].imag/n+0.5; 
	}
}

int n,m; 
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
	scanf("%d %d",&n,&m);
	for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
	for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
	int N=1,L=0;
	while(N<=n+m+1){
		L++;
		N*=2;
	} 
	for(int i=0;i<N/2;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-2));//注意這里的rev數組是對N/2做的,L要-1 
	mul(a,b,c,N);
	for(int i=0;i<n+m+1;i++) printf("%lld\n",c[i]);
}


任意模數NTT

三模數NTT

這是任意模數NTT的算法中最好理解的一種,它基於中國剩余定理。

定理5.1\(m_1,m_2 ,\dots m_n\)兩兩互質,則對於\(\forall a_1,a_2 \dots a_n\)同余方程組

\[\begin{cases} x \equiv a_1 (\bmod m_1) \\ x \equiv a_2 (\bmod m_2) \\ \dots \\ x \equiv a_n (\bmod m_n)\end{cases} \]

有整數解解,且可以用如下方式構造解

  1. \(M=\prod_{i=1}^n m_i,M_i=\frac{M}{m_i}\)
  2. \(M_i^{-1}\)為模\(m_i\)意義下\(M_i\)的逆元
  3. 則該方程組在模\(M\)意義下的唯一解為\(x=\sum_{i=1}^n a_iM_iM_i^{-1}\) ,方程組的通解可以表示為\(x+kM(k \in \mathbb{Z})\)

這就是著名的中國剩余定理(Chinese Reminder Theorem,CRT)

證明:

對於\(k \neq i\),\(a_iM_iM_i^{-1} \bmod m_k=0\), 而根據逆元的定義,\(a_iM_iM_i^{-1} \bmod m_i =a_i\). 再代入到\(\sum_{i=1}^n a_iM_iM_i^{-1}\),原方程組成立。

回到任意模數NTT問題

\(M\)意義下長度為\(n\)的序列做卷積,最大值可以到\(n^2M\).一般的題目中\(n \leq 10^5,M\leq 10^{9}\),那么結果會到\(10^{23}\)級別。用long double等存儲會丟失精度。那么我們可以選三個乘起來大於\(10^{23}\)的NTT模數998244353,1004535809,469762049(選這三個模數的好處是他們的原根都是3,所以NTT部分寫起來比較簡潔)。然后分別在這三個模數的意義下做卷積。最后考慮把答案合並,我們只考慮某一位上的值\(ans\),容易寫出:

\[\begin{cases} ans=a_1( \bmod m_1) (5.2)\\ans=a_2( \bmod m_2)(5.3)\\ans=a_3( \bmod m_3) (5.4)\end{cases} \]

顯然\(m_1,m_2,m_3\)互質,那么我們可以利用中國剩余定理直接合並。但是,直接合並把三個模數乘起來的時候會超出long long的范圍。注意到兩個模數相乘還是在long long范圍內的,可以兩兩合並,具體方法如下,

\(inv(a,m)\)表示\(a\)在模\(m\)下的逆元.根據CRT合並\((5.2)(5.3)\)有:

\[ans \equiv a_1m_2inv(m_1,m_1m_2)+a_2m_1inv(m_2,m_1m_2)(\bmod m_1m_2) \tag{5.5} \]

不妨設\(ans=km_1m_2+r\),根據\(5.4\)

\(ans=km_1 m_2+r=q m_3+a_3 \tag{5.6}\),

在模 \(m_3\) 意義下有

\(km_1 m_2+r \equiv a_3 (\bmod m_3) \tag{5.7}\)

因此\(k=(a_3-r_2)inv(m_1m_2,m_3) (\bmod m_3)\),不妨設\(k=dm_3+e\),代入\(5.6\)

\[ans=dm_1m_2m_3+em_1m_2+r \]

由於\(m_1m_2m_3>ans\),所以\(d=0\),也就是說,\(ans=em_1m_2+r\),其中\(r=a_1m_2inv(m_1,m_1m_2)+a_2m_1inv(m_2,m_1m_2),e=(a_3-r_2)inv(m_1m_2,m_3)\)

const ll mm=m1*m2;
inline ll inv(ll a,ll m);
ll mul(ll a,ll b,ll m);//要用按位乘防止溢出
ll CRT(ll a1,ll a2,ll a3){
    ll r=(mul(a1*m2%mm,inv(m2,m1),mm)+mul(a2*m1%mm,inv(m1,m2),mm))%mm;
    ll e=((a3-r)%m3+m3)%m3*inv(mm,m3)%m3;
    return ((e%C)*(mm%C)%C+r%C)%C;
}

完整代碼(LuoguP4245 【模板】任意模數NTT)

#include<iostream>
#include<cstdio>
#include<cstring>
#define m1 998244353ll
#define m2 1004535809ll
#define m3 469762049ll
#define G 3
#define maxn 1048576
using namespace std; 
typedef long long ll;
const ll mm=m1*m2;
ll C;
ll fast_pow(ll x,ll k,ll m){
	ll ans=1;
	while(k){
		if(k&1) ans=ans*x%m;
		x=x*x%m;
		k>>=1; 
	}
	return ans;
}
inline ll inv(ll a,ll m){
	return fast_pow(a%m,m-2,m); //一定要取模m 
} 

ll mul(ll a,ll b,ll m){
	ll ans=0;
	while(b){
		if(b&1) ans=(ans+a)%m;
		a=(a+a)%m;
		b>>=1;
	}
	return ans;
}
ll CRT(ll a1,ll a2,ll a3){
	//[Warning]You are not expected to understand this.
    ll r=(mul(a1*m2%mm,inv(m2,m1),mm)+mul(a2*m1%mm,inv(m1,m2),mm))%mm;
    ll e=((a3-r)%m3+m3)%m3*inv(mm,m3)%m3;
    return ((e%C)*(mm%C)%C+r%C)%C;
}

int n,m,N,L;
int rev[maxn+5];
void NTT(ll *x,int n,int type,ll mod){
	ll invG=inv(G,mod); 
	for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]); 
	for(int len=1;len<n;len*=2){
		int sz=len*2;
		ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz,mod);
		for(int l=0;l<n;l+=sz){
			int r=l+len-1;
			ll gnk=1;
			for(int i=l;i<=r;i++){
				ll tmp=x[i+len];
				x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod;
				x[i]=(x[i]+gnk*tmp%mod)%mod;
				gnk=gnk*gn1%mod; 
			}
		}
	} 
	if(type==-1){
		ll invn=inv(n,mod);
		for(int i=0;i<n;i++) x[i]=x[i]*invn%mod; 
	}
} 
void fmul(ll *a,ll *b,ll *ans,int n,ll mod){
	static ll ta[maxn+5],tb[maxn+5];
	for(int i=0;i<n;i++) ta[i]=a[i];
	for(int i=0;i<n;i++) tb[i]=b[i];
	NTT(ta,n,1,mod);
	if(a!=b) NTT(tb,n,1,mod);
	for(int i=0;i<n;i++) ans[i]=ta[i]*tb[i]%mod;
	NTT(ans,n,-1,mod);
}

ll a[maxn+5],b[maxn+5],c[3][maxn+5];
int main(){
	scanf("%d %d %lld",&n,&m,&C);
	for(int i=0;i<=n;i++){
		scanf("%lld",&a[i]);
		a[i]%=C;
	}
	for(int i=0;i<=m;i++){
		scanf("%lld",&b[i]);
		b[i]%=C;
	}
	N=1,L=0;
	while(N<n+m+1){
		N*=2;
		L++;
	}
	for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
	fmul(a,b,c[0],N,m1);
	fmul(a,b,c[1],N,m2);
	fmul(a,b,c[2],N,m3);
	for(int i=0;i<n+m+1;i++){
		printf("%lld ",CRT(c[0][i],c[1][i],c[2][i]));
	}
}

容易發現,三模數NTT需要9次FFT,不是很優秀

拆系數FFT

我們之前討論的優化都是針對FFT的,那不妨嘗試用FFT解決任意模數NTT

最簡單的想法是不取模,FFT完再取模。但是上文提到數值過大,long double會丟失精度。
int128是一個方法,但在OI比賽中不一定能使用。所以需要拆系數。

\(M_0=[\sqrt{M}]\)

\[\begin{aligned} a_i=k[a_i]M_0+b[a_i]\\ b_i=k[b_i]M_0+b[b_i]\end{aligned}\]

相當於把模數換成\(M_0\),降低大小。
代入對應的多項式

\[\begin{aligned}A(x)=K_a(x)M_0+B_a(x)\\ B(x)=K_b(x)M_0+B_b(x)\\ A(x)B(x)=K_a(x)K_b(x)M_0^2+(K_a(x)B_b(x)+K_b(x)B_a(x))M_0+B_a(x)B_b(x) \end{aligned}\]

這不就是我們提到的\((A+B)(C+D)\)形的卷積嗎?
由於\(k,b\)都不超過\(2^{15}\),於是就不容易被卡精度了。實際操作中我們不必取\(M_0=\sqrt{M}\),直接取\(M_0=2^{15}\)即可。這樣取模運算可以換成位運算,進一步減小常數。

代碼(LuoguP4245 【模板】任意模數NTT)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std; 
typedef long long ll;
struct com{
	double real;
	double imag;
	com(){
		
	} 
	com(double _real,double _imag){
		real=_real;
		imag=_imag;
	}
	com(double x){
		real=x;
		imag=0;
	}
	void operator = (const com x){
		this->real=x.real;
		this->imag=x.imag;
	}
	void operator = (const double x){
		this->real=x;
		this->imag=0;
	}
	friend com operator + (com p,com q){
		return com(p.real+q.real,p.imag+q.imag);
	}
	friend com operator + (com p,double q){
		return com(p.real+q,p.imag);
	}
	void operator += (com q){
		*this=*this+q;
	}
	void operator += (double q){
		*this=*this+q;
	}
	friend com operator - (com p,com q){
		return com(p.real-q.real,p.imag-q.imag);
	}
	friend com operator - (com p,double q){
		return com(p.real-q,p.imag);
	}
	void operator -= (com q){
		*this=*this-q;
	}
	void operator -= (double q){
		*this=*this-q;
	}
	friend com operator * (com p,com q){
		return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
	}
	friend com operator * (com p,double q){
		return com(p.real*q,p.imag*q);
	} 
	void operator *= (com q){
		*this=(*this)*q;
	}
	void operator *= (double q){
		*this=(*this)*q;
	}
	friend com operator / (com p,double q){
		return com(p.real/q,p.imag/q);
	} 
	void operator /= (double q){
		*this=(*this)/q;
	} 
	com conj(){
		return com(real,-imag);
	}
	void print(){
		printf("(%lf,%lf)\n",real,imag);
	}
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){
	for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
	for(int len=1;len<n;len*=2){
		int sz=len*2;
		for(int l=0;l<n;l+=sz){
			int r=l+len-1;
			for(int i=l;i<=r;i++){
				com tmp=x[i+len];
				x[i+len]=x[i]-tmp*w[n/sz*(i-l)];
				x[i]=x[i]+tmp*w[n/sz*(i-l)];
			}
		}
	}
}
ll mod; 
void mul(ll *ina,ll *inb,ll *inc,int n){
	static ll a[maxn+5],b[maxn+5],c[maxn+5],d[maxn+5];
	static com p[maxn+5],q[maxn+5];
	static com r[maxn+5],s[maxn+5];
	for(int i=0;i<n;i++){
		ina[i]=(ina[i]+mod)%mod;
		inb[i]=(inb[i]+mod)%mod;
		a[i]=ina[i]>>15;
		b[i]=ina[i]&((1<<15)-1);
		c[i]=inb[i]>>15;
		d[i]=inb[i]&((1<<15)-1);
	}
	for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
	for(int i=0;i<n;i++){
		p[i]=com(a[i],b[i]);//打包A,B 
		q[i]=com(c[i],d[i]);//打包C,D 
	}
	fft(p,n);
	fft(q,n);
	for(int i=0;i<n;i++){
//		p[i].print();
		int j=(i==0?0:n-i);
		//得到DFT(A),DFT(B),DFT(C),DFT(D) 
		com da=(p[i]+p[j].conj())*0.5;
		com db=(p[i]-p[j].conj())*com(0,-0.5);
		com dc=(q[i]+q[j].conj())*0.5;
		com dd=(q[i]-q[j].conj())*com(0,-0.5);
		r[j]=da*dc+da*dd*com(0,1);//打包AC,AD 
		s[j]=db*dc+db*dd*com(0,1); //打包BC,BD 
	}
	fft(r,n);
	fft(s,n);
	for(int i=0;i<n;i++){
		ll ac,ad,bc,bd; 
		ac=(ll)(r[i].real/n+0.5)%mod;
        ad=(ll)(r[i].imag/n+0.5)%mod;
        bc=(ll)(s[i].real/n+0.5)%mod;
        bd=(ll)(s[i].imag/n+0.5)%mod;
        inc[i]=((ac<<30)+((ad+bc)<<15)+bd)%mod;
	}
}

int n,m; 
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
	scanf("%d %d %lld",&n,&m,&mod);
	for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
	for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
	int N=1,L=0;
	while(N<=n+m+1){
		L++;
		N*=2;
	} 
	for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
	mul(a,b,c,N);
	for(int i=0;i<n+m+1;i++) printf("%lld ",c[i]);
}

更簡潔的實現


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM