再探快速傅里葉變換(FFT)學習筆記(其三)(循環卷積的Bluestein算法+分治FFT+FFT的優化+任意模數NTT)
寫在前面
為了不使篇幅過長,預計將把基於論文的學習筆記分為三部分:
- DFT,IDFT,FFT的定義,實現與證明:快速傅里葉變換(FFT)學習筆記(其一)
- NTT的實現與證明:快速傅里葉變換(FFT)學習筆記(其二)
- 任意模數NTT與FFT的優化技巧
一些約定
- \([p(x)]=\begin{cases}1,p(x)為真 \\ 0,p(x)為假 \end{cases}\)
- 本文中序列的下標從0開始
- 若\(s\)是一個序列,\(|s|\)表示\(s\)的長度
- 若大寫字母如\(F(x)\)表示一個多項式,那么對應的小寫字母如\(f\)表示多項式的每一項系數,即\(F(x)=\sum_{i=0}^{n-1} f_ix^i\)
循環卷積
DFT卷積的本質
考慮在(其一)中提到的卷積的定義式。
我們一般做FFT時忽略了式子中的\(\bmod\),其實它是在\(\bmod 2^q\)的意義下的循環卷積,只是因為\(|a|,|b|,|c|<2^q\),所以取不取模都沒什么影響。
如果序列長度\(n\)是2的整數次冪,那么直接做就可以了。
如果序列長度\(n\)不是2的整數次冪考慮暴力的做法:先做一次普通FFT,再把\(c_{k+n}\)加到\(c_k\)上。但是這樣在做多次FFT時就必須一次一次做,比如多項式快速冪。下面給出了一種在\(O(n \log n)\)的時間內實現任意長度循環卷積的算法:Bluestein’s Algorithm
Bluestein’s Algorithm
注:原論文的推導可能有誤
考慮DFT的式子
不妨設
\(x_j=a_j \omega_n^{\frac{j^2}{2}}=a_j(\cos\frac{j^2\pi}{n}+ \text{i}\sin{\frac{j^2\pi}{n}})\)
\(y_j=\omega_n^{-\frac{j^2}{2}}= \cos \frac{\pi j^2}{n}-\text{i}\sin \frac{\pi j^2}{n}\)
那么\(a_i'=\omega_n^{\frac{j^2}{2}}\sum_{j=0}^{n-1} x_j y_{i-j}\)
這已經很類似卷積的形式了,但是注意到\(j\)的上界是\(n-1\)而不是\(i\),\(j-i\)可能為負數。那么我們把\(y\)數組的長度擴大到\(2n\),定義:
\(y_j=\omega_n^{-\frac{(j-n)^2}{2}}= \cos \frac{\pi (j-n)^2}{n}-\text{i}\sin \frac{\pi (j-n)^2}{n}\).
這樣\(j<n\)的時候就對應了\(j-i\)為負數的情形,\(j\geq n\)就對應了\(j-i\)為正的情形。然后對\(x\)和\(y\)用一般的FFT,最后的答案存儲在\(i+n\)的位置上,也就是說真正的\(a'_i\)實際上對應了乘積結果的\((x \cdot y)_{i+n}\)
這樣,我們就只做了3次FFT就求出了任意長度循環DFT。逆變換同理,只是換成共軛復數。注意到在上述的推導中我們沒有用到單位根\(\omega\)的任何性質,因此這里的\(\omega\)可以換成任意復數\(z\),這樣的變換稱為Chirp Z-Transform,CZT.可見,CZT實際上是DFT的廣義形式。
代碼實現:
//com是手寫復數類,省略
void fft(com *x,int *rev,int n,int type){
//為節約篇幅,fft部分省略,x為系數序列,rev為反轉數組,n為長度,type=1表示DFT,type=-1表示IDFT
}
void bluestein(com *a,int n,int type){
//a為系數序列,n為長度,type=1表示DFT,type=-1表示IDFT
static com x[maxn*4+5],y[maxn*4+5];
static int rev[maxn*4+5];
memset(x,0,sizeof(x));
memset(y,0,sizeof(y));
//FFT前的預處理
int N=1,L=0;
while(N<n*4){
L++;
N*=2;
}
for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
//x[i],y[i]的定義見上式
for(int i=0;i<n;i++) x[i]=com(cos(pi*i*i/n),type*sin(pi*i*i/n))*a[i];
for(int i=0;i<n*2;i++) y[i]=com(cos(pi*(i-n)*(i-n)/n),-type*sin(pi*(i-n)*(i-n)/n));
fft(x,rev,N,1);
fft(y,rev,N,1);
for(int i=0;i<N;i++) x[i]*=y[i];
fft(x,rev,N,-1);
for(int i=0;i<n;i++){
a[i]=x[i+n]*com(cos(pi*i*i/n),type*sin(pi*i*i/n));//記得乘上常數
if(type==-1) a[i]/=n;//一定記得除以n,因為做一次Bluestein相當於一次FFT,IFFT最后要除n,這里也要除n
}
}
例題
[POJ 2821]TN's Kindom III(任意長度循環卷積的Bluestein算法)
分治FFT
一般我們用FFT的時候,序列的所有元素都已知。但是,如果序列本身是根據卷積定義的,就無法直接套FFT
舉一個最簡單的例子\(f_i =\sum_{j=1}^i f_{i-j}g_j\).其中\(g\)給定,求\(f\). 由於我們卷積的時后后面的數基於前面的數,無法快速計算,時間復雜度退化到\(O(n^2)\). (雖然這個式子可以用(其四)中將會提到的多項式求逆解決,但是分治FFT更通用,可以處理很復雜的式子)
考慮分治: 設當前分治區間為\([l,r]\),假設我們求出了\([l,mid]\)的答案,那么可以求出這些點對\([mid+1,r]\)的影響。那么右半邊的點\(x \in [mid+1,r]\)得到的貢獻是\(\Delta_x=\sum_{i=l}^{mid} f_i g_{x-i}\).只需要把下標偏移一下(如\([l,mid]\)偏移成\([0,mid-l]\),就是一個卷積的形式,可以運用FFT或NTT計算,計算完之后,把答案累加到數組上.
偽代碼如下:
poly f,g;//上述的f,g
procedure calc(L,mid,R){
for i in [L,mid] : a[i-L] <- f[i]//下標偏移
for i in [1,R-L] : b[i-1] <- g[i]
a <- mul(a,b);//fft或ntt做多項式乘法
for i in [mid+1,R] f[i] <- f[i]+a[i-l-1]//累加貢獻
}
procedure solve(l,mid){
if(l==r) return;
mid <- (l+r)/2
solve(l,mid);
calc(l,mid,r);
solve(mid+1,r)
}
時間復雜度分析:
\(T(n)=2T(\frac{n}{2})+n \log_2n\), 總復雜度\(\Theta(n \log^2n)\)
下面是基於NTT的模板代碼(Luogu 4721)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 300000
#define G 3
#define invG 332748118
#define inv2 499122177
#define mod 998244353
using namespace std;
typedef long long ll;
inline ll fast_pow(ll x,ll k){
ll ans=1;
while(k){
if(k&1) ans=ans*x%mod;
x=x*x%mod;
k>>=1;
}
return ans;
}
inline ll inv(ll x){
return fast_pow(x,mod-2);
}
void NTT(ll *x,int n,int type){
static int rev[maxn+5];
int tn=1;
int k=0;
while(tn<n){
tn*=2;
k++;
}
for(int i=0;i<tn;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
for(int i=0;i<n;i++){
if(i<rev[i]) swap(x[i],x[rev[i]]);
}
for(int len=1;len<n;len*=2){
int sz=len*2;
ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz);
for(int l=0;l<n;l+=sz){
int r=l+len-1;
ll gnk=1;
for(int i=l;i<=r;i++){
ll tmp=x[i+len];
x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod;
x[i]=(x[i]+gnk*tmp%mod)%mod;
gnk=gnk*gn1%mod;
}
}
}
if(type==-1){
int invsz=inv(n);
for(int i=0;i<n;i++) x[i]=x[i]*invsz%mod;
}
}
void mul(ll *a,ll *b,ll *ans,int sz){
NTT(a,sz,1);
NTT(b,sz,1);
for(int i=0;i<sz;i++) ans[i]=a[i]*b[i]%mod;
NTT(ans,sz,-1);
}
void cdq_divide(ll *f,ll *g,int l,int r){
static ll tmpa[maxn+5],tmpb[maxn+5];
if(l==r) return;
int mid=(l+r)>>1;
cdq_divide(f,g,l,mid);
int tn=1,k=0;
while(tn<r-l){
k++;
tn*=2;
}
for(int i=0;i<tn;i++) tmpa[i]=tmpb[i]=0;
for(int i=l;i<=mid;i++) tmpa[i-l]=f[i];
for(int i=1;i<=r-l;i++) tmpb[i-1]=g[i];
mul(tmpa,tmpb,tmpa,tn);
for(int i=mid+1;i<=r;i++) f[i]=(f[i]+tmpa[i-l-1])%mod;
cdq_divide(f,g,mid+1,r);
}
int n;
ll f[maxn+5],g[maxn+5];
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++) scanf("%lld",&g[i]);
f[0]=1;
cdq_divide(f,g,0,n-1);
for(int i=0;i<n;i++) printf("%lld ",f[i]);
}
容易發現,許多dp方程都有分治FFT的形式。對於此類dp方程,我們可以用分治FFT將轉移復雜度由\(O(n^2)\)降到\(O(n \log^2 n)\)
例題
[Codeforces 553E]Kyoya and Train(期望DP+Floyd+分治FFT)
FFT的弱常數優化
下面介紹一些優化FFT的常數的技巧。雖然這些技巧都只是對FFT的一些小優化,但是在某些題目中優化效果極其明顯。
復雜算式中減少FFT次數
如果我們要計算一個復雜的多項式,如\(A(x)=B(x)C(x)+D(x)E(x)\)
最簡單的方法是分別計算\(B(x)C(x)\)和\(D(x)E(x)\),這樣需要做6次FFT. 但是如果先對\(B,C,D,E\)做DFT,然后直接用點值表達式計算\(a_i=b_ic_i+d_ie_i\),再把\(a\)IDFT回去。這樣只需要做5次FFT,且多項式越復雜,這樣的常數就越優秀。
例題
[BZOJ 3771] Triple(FFT+容斥原理+生成函數)
利用循環卷積
考慮對於兩個長度為\(n\)的序列\(a,b\),計算它們的卷積\(c\)的第\(0.5n\)項到第\(1.5n\)項。傳統的方法是補0擴充到\(2n\)的序列。但是因為FFT求得實際上是我們已經提到過的循環卷積,所以如果只補0到\(1.5n\)(上取整),對第\(0.5n\)項到第\(1.5n\)項無影響
在基於牛頓迭代的算法中,能起到較明顯的優化作用。會在(其四)中詳細介紹這些算法。
小范圍暴力
由於FFT的常數較大。在數據范圍較小的時候甚至不如\(O(n^2)\)的暴力卷積的優秀。因此在做多次FFT和分治FFT的時候,如果當前的序列長度較小,可以采用暴力算法。
例題
[BZOJ 3509] [CodeChef] COUNTARI (FFT+分塊)
快速冪乘法次數的優化
這個東西實際上比較雞肋。因為多項式快速冪可以通過多項式\(\ln\)和\(\exp\)優化到\(O(n \log n)\).但是為了應對考場上時間不夠的情況,我們來考慮如何通過簡單的實現來減少\(O(n \log^2n)\)的倍增快速冪的復雜度。
倍增法的思路是根據前面算過的乘積快速算出當前的乘積,如\(1 \to 2 \to 4 \to 8\).最壞情況下需要\(2 \log_2n+C\)次乘法。但這並不是下界。我們定義additional chain為一條鏈,最開始是1,后一個數減前一個數的差是鏈上這個是前面的某一個數。例如\(1 \to 2 \to 4 \to 6\).\(6-4=2\)在前面出現過,\(4-2=2\)在前面出現過。那么根據這條additional chain計算6次冪的時候,可以從1次冪出發,用1次冪乘1次冪得到2次冪,再乘2次冪得到4次冪,再乘2次冪得到6次冪。
很可惜,對於數\(k\)求出得到\(k\)的最短additional chain是NP-hard的。但是有很好的近似算法。近似算法基於BFS。每次我們對於隊頭的數\(x\),枚舉它對應的additional chain中的數\(y\),如果\(x+y\)還沒有訪問過那么將其入隊,並將\(x\)對應的鏈后面接上\(x+y\). 這個預處理是\(O(k)\)的,且對快速冪的常數優化很顯著。
如果\(k\)很大,比如\(10^{10000}\),可以采用十進制快速冪。但是用Method of Four Russians(俗稱四毛子算法),可以將乘法次數減少到\(\log_2n+O(\frac{\log n}{\log \log n})\).具體方法見2017年國家集訓隊論文《非常規大小分塊算法初探》
FFT的強常數優化
FFT的強常數優化一般是通過減少FFT次數來實現的
在這一節中,我們記\(DFT(A(x))\)表示多項式\(A(x)\)(或序列)做DFT之后的結果,\(IDFT(A(x))\)同理
我們現在考慮最常見的一個模型:給出兩個長度為\(n+1\)和\(m+1\)的多項式\(A(x),B(x)\),我們要計算他們的線性卷積。假設長度已經補齊為第一個大於\(n+m+1\)的2的整數冪\(L\)。
顯然直接搞需要3次長度為\(L\)的FFT。毒瘤的Vladimir Smykalov在cf上最先給出了這個問題的優化算法。
DFT的合並
DFT的合並是指,對於兩個序列\(a\),\(b\),我們只通過一次FFT就求出\(DFT(a),DFT(b)\)
不妨設:
接下來我們開始推導公式。注意為了簡潔,我們記\(X=\frac{2 \pi jk}{2L}\),\(\text{conj}(z)\)表示\(z\)的共軛復數
也就是說,只要一次DFT算出\(DFT(p)\),就可以把序列反轉再取共軛復數得到\(DFT(q)\).
由於DFT是線性變換,
其中\(j\)為\(k\)翻轉后的數,即\(j=\begin{cases}0,k=0 \\ L-k ,k>0 \end{cases}\)
又由\((4.1),(4.2)\)式
這樣我們就可以從\(q'\)推出\(a',b'\),也就是說一次DFT就能得到\(a'\)和\(b'\)了.
我們一共做了2次長度為\(L\)的FFT.
代碼(UOJ#34):
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std;
typedef long long ll;
struct com{
double real;
double imag;
com(){
}
com(double _real,double _imag){
real=_real;
imag=_imag;
}
com(double x){
real=x;
imag=0;
}
void operator = (const com x){
this->real=x.real;
this->imag=x.imag;
}
void operator = (const double x){
this->real=x;
this->imag=0;
}
friend com operator + (com p,com q){
return com(p.real+q.real,p.imag+q.imag);
}
friend com operator + (com p,double q){
return com(p.real+q,p.imag);
}
void operator += (com q){
*this=*this+q;
}
void operator += (double q){
*this=*this+q;
}
friend com operator - (com p,com q){
return com(p.real-q.real,p.imag-q.imag);
}
friend com operator - (com p,double q){
return com(p.real-q,p.imag);
}
void operator -= (com q){
*this=*this-q;
}
void operator -= (double q){
*this=*this-q;
}
friend com operator * (com p,com q){
return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
}
friend com operator * (com p,double q){
return com(p.real*q,p.imag*q);
}
void operator *= (com q){
*this=(*this)*q;
}
void operator *= (double q){
*this=(*this)*q;
}
friend com operator / (com p,double q){
return com(p.real/q,p.imag/q);
}
void operator /= (double q){
*this=(*this)/q;
}
com conj(){
return com(real,-imag);
}
void print(){
printf("%lf + %lf i ",real,imag);
}
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){
for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
for(int len=1;len<n;len*=2){
int sz=len*2;
for(int l=0;l<n;l+=sz){
int r=l+len-1;
for(int i=l;i<=r;i++){
com tmp=x[i+len];
x[i+len]=x[i]-tmp*w[n/sz*(i-l)];//w(sz,k)=w(n,n/sz*k)
x[i]=x[i]+tmp*w[n/sz*(i-l)];
}
}
}
}
void mul(ll *a,ll *b,ll *c,int n){
static com p[maxn+5],r[maxn+5];
for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));//預處理單位根
for(int i=0;i<n;i++) p[i]=com(a[i],b[i]);//p[i]=a[i]+ib[i]
fft(p,n);
for(int i=0;i<n;i++){
int j=(i>0?(n-i):0);//0的位置需要特判一下
com q=p[j];
r[j]=(p[i]*p[i]-q.conj()*q.conj())*com(0,-0.25);//按照上面的式子
}
fft(r,n);//這里是用了第一篇中提到的反轉技巧
for(int i=0;i<n;i++) c[i]=r[i].real/n+0.5;
}
int n,m;
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
scanf("%d %d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
int N=1,L=0;
while(N<n+m+1){
L++;
N*=2;
}
for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
mul(a,b,c,N);
for(int i=0;i<n+m+1;i++) printf("%lld\n",c[i]);
}
IDFT的合並
IDFT的合並是指,對於兩個序列\(a\),\(b\),我們只通過一次FFT就求出\(IDFT(a),IDFT(b)\)
IDFT的合並非常簡單。
設\(r(x)=a(x)+\text{i}b(x)\)
由於IDFT是線性變換
\(IDFT(r(x))=IDFT(a(x))+\text{i}IDFT(b(x))\)
又因為\(a(x)\)和\(b(x)\)都是實數序列,那么\(IDFT(r(x))\)的實部就是\(IDFT(a(x))\),虛部就是\(IDFT(b(x))\)
形如\((A+B)(C+D)\)的卷積的優化
在這一節中我們討論\((A(x)+B(x))(C(x)+D(x))\)形式的卷積的優化.
一般的做法是對\(A,B,C,D\)都做一次DFT,然后按照這個式子直接計算,最后再IDFT回來。需要5次FFT.
而根據上面的合並技巧,先把\(A(x),B(x)\)合並DFT,\(C(x),D(x)\)合並DFT得到點值表達式.
由於\((A(x)+B(x))(C(x)+D(x))=A(x)C(x)+A(x)D(x)+B(x)C(x)+B(x)D(x)\)
我們可以直接把點值表達式相乘得到這4個多項式。對於這4個多項式,分成2組合並做IDFT即可。
總共需要4次FFT.
大致代碼如下:
void mul(ll *a,ll *b,ll *c,ll *d,ll *ans,int n){
static com p[maxn+5],q[maxn+5];
static com r[maxn+5],s[maxn+5];
for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
for(int i=0;i<n;i++){
p[i]=com(a[i],b[i]);//打包A,B
q[i]=com(c[i],d[i]);//打包C,D
}
fft(p,n);
fft(q,n);
for(int i=0;i<n;i++){
int j=(i==0?0:n-i);
//得到DFT(A),DFT(B),DFT(C),DFT(D)
com da=(p[i]+p[j].conj())*0.5;
com db=(p[i]-p[j].conj())*com(0,-0.5);
com dc=(q[i]+q[j].conj())*0.5;
com dd=(q[i]-q[j].conj())*com(0,-0.5);
r[j]=da*dc+da*dd*com(0,1);//打包AC,AD
s[j]=db*dc+db*dd*com(0,1); //打包BC,BD
}
fft(r,n);
fft(s,n);
for(int i=0;i<n;i++){
ll ac,ad,bc,bd;
ac=(ll)(r[i].real/n+0.5);
ad=(ll)(r[i].imag/n+0.5);
bc=(ll)(s[i].real/n+0.5);
bd=(ll)(s[i].imag/n+0.5);
ans[i]=ac+ad+bc+bd;
}
}
卷積的終極優化
上述優化中我們只用到了DFT的思想。現在我們利用FFT的思想繼續優化
同樣拆分奇偶項,\(A(x)=A_0(x^2)+xA_1(x^2)\)
我們只需要知道上式中\(x^0,x^1,x^2\)的系數
發現\(A_0(x^2)B_1(x^2)+A_1(x^2)B_0(x^2)\)是奇數項的系數,\(A_0(x^2)B_0(x^2)\)和\(A_1(x^2)B_1(x^2)\)是偶數項的系數,而偶數項的兩個東西都可以看成一個關於\(x^2\)的多項式。
我們先優化DFT的過程,觀察\((4.6)\)式的乘積形式\((A_0(x^2)+xA_1(x^2))(B_0(x^2)+xB_1(x^2))\).
我們發現,這個形式和上一節的\((A+B)(C+D)\)很像,可以類似地優化。
令\(p_k={a_0}_k+\text{i}{a_1}_k,q_k={b_0}_k+\text{i}{b_1}_k\)
然后合並IDFT,再設兩個輔助多項式
(注意我們把\(x^2\)換元成\(x\),做DFT的時候要乘上單位根)
那么我們只需要計算出\(IDFT(G(x))\)和\(IDFT(F(x))\)
設\(R(x)=G(x)+\mathrm{i} F(x)\)
那么因為IDFT是線性變換,\(IDFT(R(x))=IDFT(G(x))+\mathrm{i} IDFT(F(x))\)
(IDFT的線性性這里不做證明,容易發現兩個點值表達式相加再IDFT回來,顯然系數也會相加)
顯然這兩個多項式IDFT的結果是實數。故我們只要求出\(IDFT(R(x))\),每一項系數的實部就是偶數項系數\(G(x)\),虛部就是奇數項系數\(F(x)\)
我們再考慮把合並DFT弄進去,即式\((4.3)(4.4)(4.5)\)
接下來我們嘗試用\(DFT(p_k),DFT(q_k)\)來表示\(R(x)=G(x)+\text{i}F(x)\),為了推導簡潔,我們省略\(DFT\)不寫
那么
和上一節的\((A+B)(C+D)\)不同,我們只用了3次長度為\(L/2\)的FFT,就求出了答案,這是由於FFT本身的性質。因為長度縮減了一半,我們不妨稱它為\(1.5\)次FFT.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std;
typedef long long ll;
struct com{
double real;
double imag;
com(){
}
com(double _real,double _imag){
real=_real;
imag=_imag;
}
com(double x){
real=x;
imag=0;
}
void operator = (const com x){
this->real=x.real;
this->imag=x.imag;
}
void operator = (const double x){
this->real=x;
this->imag=0;
}
friend com operator + (com p,com q){
return com(p.real+q.real,p.imag+q.imag);
}
friend com operator + (com p,double q){
return com(p.real+q,p.imag);
}
void operator += (com q){
*this=*this+q;
}
void operator += (double q){
*this=*this+q;
}
friend com operator - (com p,com q){
return com(p.real-q.real,p.imag-q.imag);
}
friend com operator - (com p,double q){
return com(p.real-q,p.imag);
}
void operator -= (com q){
*this=*this-q;
}
void operator -= (double q){
*this=*this-q;
}
friend com operator * (com p,com q){
return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
}
friend com operator * (com p,double q){
return com(p.real*q,p.imag*q);
}
void operator *= (com q){
*this=(*this)*q;
}
void operator *= (double q){
*this=(*this)*q;
}
friend com operator / (com p,double q){
return com(p.real/q,p.imag/q);
}
void operator /= (double q){
*this=(*this)/q;
}
com conj(){
return com(real,-imag);
}
void print(){
printf("%lf + %lf i ",real,imag);
}
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){
for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
for(int len=1;len<n;len*=2){
int sz=len*2;
for(int l=0;l<n;l+=sz){
int r=l+len-1;
for(int i=l;i<=r;i++){
com tmp=x[i+len];
x[i+len]=x[i]-tmp*w[n/sz*(i-l)];//w(sz,k)=w(n,n/sz*k)
x[i]=x[i]+tmp*w[n/sz*(i-l)];
}
}
}
}
void mul(ll *a,ll *b,ll *c,int n){
static com p[maxn+5],q[maxn+5],r[maxn+5];
for(int i=0;i<n;i++){//合並做DFT
if(i%2==1){
p[i/2].imag=a[i];
q[i/2].imag=b[i];
}else{
p[i/2].real=a[i];
q[i/2].real=b[i];
}
}
n/=2;
for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
fft(q,n);
fft(p,n);
for(int i=0;i<n;i++){
int j=(i>0?(n-i):0);
r[j]=p[i]*q[i]-(w[i]+1)*(p[i]-p[j].conj())*(q[i]-q[j].conj())*0.25;
}
fft(r,n);
for(int i=0;i<n;i++){
c[i*2]=r[i].real/n+0.5;
c[i*2+1]=r[i].imag/n+0.5;
}
}
int n,m;
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
scanf("%d %d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
int N=1,L=0;
while(N<=n+m+1){
L++;
N*=2;
}
for(int i=0;i<N/2;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-2));//注意這里的rev數組是對N/2做的,L要-1
mul(a,b,c,N);
for(int i=0;i<n+m+1;i++) printf("%lld\n",c[i]);
}
任意模數NTT
三模數NTT
這是任意模數NTT的算法中最好理解的一種,它基於中國剩余定理。
定理5.1 若\(m_1,m_2 ,\dots m_n\)兩兩互質,則對於\(\forall a_1,a_2 \dots a_n\)同余方程組
\[\begin{cases} x \equiv a_1 (\bmod m_1) \\ x \equiv a_2 (\bmod m_2) \\ \dots \\ x \equiv a_n (\bmod m_n)\end{cases} \]有整數解解,且可以用如下方式構造解
- 設\(M=\prod_{i=1}^n m_i,M_i=\frac{M}{m_i}\)
- 設\(M_i^{-1}\)為模\(m_i\)意義下\(M_i\)的逆元
- 則該方程組在模\(M\)意義下的唯一解為\(x=\sum_{i=1}^n a_iM_iM_i^{-1}\) ,方程組的通解可以表示為\(x+kM(k \in \mathbb{Z})\)
這就是著名的中國剩余定理(Chinese Reminder Theorem,CRT)
證明:
對於\(k \neq i\),\(a_iM_iM_i^{-1} \bmod m_k=0\), 而根據逆元的定義,\(a_iM_iM_i^{-1} \bmod m_i =a_i\). 再代入到\(\sum_{i=1}^n a_iM_iM_i^{-1}\),原方程組成立。
回到任意模數NTT問題
模\(M\)意義下長度為\(n\)的序列做卷積,最大值可以到\(n^2M\).一般的題目中\(n \leq 10^5,M\leq 10^{9}\),那么結果會到\(10^{23}\)級別。用long double
等存儲會丟失精度。那么我們可以選三個乘起來大於\(10^{23}\)的NTT模數998244353,1004535809,469762049(選這三個模數的好處是他們的原根都是3,所以NTT部分寫起來比較簡潔)。然后分別在這三個模數的意義下做卷積。最后考慮把答案合並,我們只考慮某一位上的值\(ans\),容易寫出:
顯然\(m_1,m_2,m_3\)互質,那么我們可以利用中國剩余定理直接合並。但是,直接合並把三個模數乘起來的時候會超出long long
的范圍。注意到兩個模數相乘還是在long long
范圍內的,可以兩兩合並,具體方法如下,
記\(inv(a,m)\)表示\(a\)在模\(m\)下的逆元.根據CRT合並\((5.2)(5.3)\)有:
不妨設\(ans=km_1m_2+r\),根據\(5.4\)有
\(ans=km_1 m_2+r=q m_3+a_3 \tag{5.6}\),
在模 \(m_3\) 意義下有
\(km_1 m_2+r \equiv a_3 (\bmod m_3) \tag{5.7}\)
因此\(k=(a_3-r_2)inv(m_1m_2,m_3) (\bmod m_3)\),不妨設\(k=dm_3+e\),代入\(5.6\)得
由於\(m_1m_2m_3>ans\),所以\(d=0\),也就是說,\(ans=em_1m_2+r\),其中\(r=a_1m_2inv(m_1,m_1m_2)+a_2m_1inv(m_2,m_1m_2),e=(a_3-r_2)inv(m_1m_2,m_3)\)
const ll mm=m1*m2;
inline ll inv(ll a,ll m);
ll mul(ll a,ll b,ll m);//要用按位乘防止溢出
ll CRT(ll a1,ll a2,ll a3){
ll r=(mul(a1*m2%mm,inv(m2,m1),mm)+mul(a2*m1%mm,inv(m1,m2),mm))%mm;
ll e=((a3-r)%m3+m3)%m3*inv(mm,m3)%m3;
return ((e%C)*(mm%C)%C+r%C)%C;
}
完整代碼(LuoguP4245 【模板】任意模數NTT)
#include<iostream>
#include<cstdio>
#include<cstring>
#define m1 998244353ll
#define m2 1004535809ll
#define m3 469762049ll
#define G 3
#define maxn 1048576
using namespace std;
typedef long long ll;
const ll mm=m1*m2;
ll C;
ll fast_pow(ll x,ll k,ll m){
ll ans=1;
while(k){
if(k&1) ans=ans*x%m;
x=x*x%m;
k>>=1;
}
return ans;
}
inline ll inv(ll a,ll m){
return fast_pow(a%m,m-2,m); //一定要取模m
}
ll mul(ll a,ll b,ll m){
ll ans=0;
while(b){
if(b&1) ans=(ans+a)%m;
a=(a+a)%m;
b>>=1;
}
return ans;
}
ll CRT(ll a1,ll a2,ll a3){
//[Warning]You are not expected to understand this.
ll r=(mul(a1*m2%mm,inv(m2,m1),mm)+mul(a2*m1%mm,inv(m1,m2),mm))%mm;
ll e=((a3-r)%m3+m3)%m3*inv(mm,m3)%m3;
return ((e%C)*(mm%C)%C+r%C)%C;
}
int n,m,N,L;
int rev[maxn+5];
void NTT(ll *x,int n,int type,ll mod){
ll invG=inv(G,mod);
for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
for(int len=1;len<n;len*=2){
int sz=len*2;
ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz,mod);
for(int l=0;l<n;l+=sz){
int r=l+len-1;
ll gnk=1;
for(int i=l;i<=r;i++){
ll tmp=x[i+len];
x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod;
x[i]=(x[i]+gnk*tmp%mod)%mod;
gnk=gnk*gn1%mod;
}
}
}
if(type==-1){
ll invn=inv(n,mod);
for(int i=0;i<n;i++) x[i]=x[i]*invn%mod;
}
}
void fmul(ll *a,ll *b,ll *ans,int n,ll mod){
static ll ta[maxn+5],tb[maxn+5];
for(int i=0;i<n;i++) ta[i]=a[i];
for(int i=0;i<n;i++) tb[i]=b[i];
NTT(ta,n,1,mod);
if(a!=b) NTT(tb,n,1,mod);
for(int i=0;i<n;i++) ans[i]=ta[i]*tb[i]%mod;
NTT(ans,n,-1,mod);
}
ll a[maxn+5],b[maxn+5],c[3][maxn+5];
int main(){
scanf("%d %d %lld",&n,&m,&C);
for(int i=0;i<=n;i++){
scanf("%lld",&a[i]);
a[i]%=C;
}
for(int i=0;i<=m;i++){
scanf("%lld",&b[i]);
b[i]%=C;
}
N=1,L=0;
while(N<n+m+1){
N*=2;
L++;
}
for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
fmul(a,b,c[0],N,m1);
fmul(a,b,c[1],N,m2);
fmul(a,b,c[2],N,m3);
for(int i=0;i<n+m+1;i++){
printf("%lld ",CRT(c[0][i],c[1][i],c[2][i]));
}
}
容易發現,三模數NTT需要9次FFT,不是很優秀
拆系數FFT
我們之前討論的優化都是針對FFT的,那不妨嘗試用FFT解決任意模數NTT
最簡單的想法是不取模,FFT完再取模。但是上文提到數值過大,long double
會丟失精度。
int128
是一個方法,但在OI比賽中不一定能使用。所以需要拆系數。
設\(M_0=[\sqrt{M}]\)
相當於把模數換成\(M_0\),降低大小。
代入對應的多項式
這不就是我們提到的\((A+B)(C+D)\)形的卷積嗎?
由於\(k,b\)都不超過\(2^{15}\),於是就不容易被卡精度了。實際操作中我們不必取\(M_0=\sqrt{M}\),直接取\(M_0=2^{15}\)即可。這樣取模運算可以換成位運算,進一步減小常數。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std;
typedef long long ll;
struct com{
double real;
double imag;
com(){
}
com(double _real,double _imag){
real=_real;
imag=_imag;
}
com(double x){
real=x;
imag=0;
}
void operator = (const com x){
this->real=x.real;
this->imag=x.imag;
}
void operator = (const double x){
this->real=x;
this->imag=0;
}
friend com operator + (com p,com q){
return com(p.real+q.real,p.imag+q.imag);
}
friend com operator + (com p,double q){
return com(p.real+q,p.imag);
}
void operator += (com q){
*this=*this+q;
}
void operator += (double q){
*this=*this+q;
}
friend com operator - (com p,com q){
return com(p.real-q.real,p.imag-q.imag);
}
friend com operator - (com p,double q){
return com(p.real-q,p.imag);
}
void operator -= (com q){
*this=*this-q;
}
void operator -= (double q){
*this=*this-q;
}
friend com operator * (com p,com q){
return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
}
friend com operator * (com p,double q){
return com(p.real*q,p.imag*q);
}
void operator *= (com q){
*this=(*this)*q;
}
void operator *= (double q){
*this=(*this)*q;
}
friend com operator / (com p,double q){
return com(p.real/q,p.imag/q);
}
void operator /= (double q){
*this=(*this)/q;
}
com conj(){
return com(real,-imag);
}
void print(){
printf("(%lf,%lf)\n",real,imag);
}
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){
for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
for(int len=1;len<n;len*=2){
int sz=len*2;
for(int l=0;l<n;l+=sz){
int r=l+len-1;
for(int i=l;i<=r;i++){
com tmp=x[i+len];
x[i+len]=x[i]-tmp*w[n/sz*(i-l)];
x[i]=x[i]+tmp*w[n/sz*(i-l)];
}
}
}
}
ll mod;
void mul(ll *ina,ll *inb,ll *inc,int n){
static ll a[maxn+5],b[maxn+5],c[maxn+5],d[maxn+5];
static com p[maxn+5],q[maxn+5];
static com r[maxn+5],s[maxn+5];
for(int i=0;i<n;i++){
ina[i]=(ina[i]+mod)%mod;
inb[i]=(inb[i]+mod)%mod;
a[i]=ina[i]>>15;
b[i]=ina[i]&((1<<15)-1);
c[i]=inb[i]>>15;
d[i]=inb[i]&((1<<15)-1);
}
for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
for(int i=0;i<n;i++){
p[i]=com(a[i],b[i]);//打包A,B
q[i]=com(c[i],d[i]);//打包C,D
}
fft(p,n);
fft(q,n);
for(int i=0;i<n;i++){
// p[i].print();
int j=(i==0?0:n-i);
//得到DFT(A),DFT(B),DFT(C),DFT(D)
com da=(p[i]+p[j].conj())*0.5;
com db=(p[i]-p[j].conj())*com(0,-0.5);
com dc=(q[i]+q[j].conj())*0.5;
com dd=(q[i]-q[j].conj())*com(0,-0.5);
r[j]=da*dc+da*dd*com(0,1);//打包AC,AD
s[j]=db*dc+db*dd*com(0,1); //打包BC,BD
}
fft(r,n);
fft(s,n);
for(int i=0;i<n;i++){
ll ac,ad,bc,bd;
ac=(ll)(r[i].real/n+0.5)%mod;
ad=(ll)(r[i].imag/n+0.5)%mod;
bc=(ll)(s[i].real/n+0.5)%mod;
bd=(ll)(s[i].imag/n+0.5)%mod;
inc[i]=((ac<<30)+((ad+bc)<<15)+bd)%mod;
}
}
int n,m;
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
scanf("%d %d %lld",&n,&m,&mod);
for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
int N=1,L=0;
while(N<=n+m+1){
L++;
N*=2;
}
for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
mul(a,b,c,N);
for(int i=0;i<n+m+1;i++) printf("%lld ",c[i]);
}