FFT(快速傅立葉變換)和NTT(快速數論變換)看上去很高端,真正搞懂了就很simple了辣。
首先給出多項式的一些定義(初中數學內容):
形如Σaixi的式子就是多項式!
多項式中每個單項式叫做多項式的項。
這些單項式中的最高次數,就是這個多項式的次數。
有幾個不同的元也是多項式,但在下面將不被考慮。
注意:(n+1)個點可以唯一確定一個n次多項式(兩點定線啊之類的)。
然后就是一些比較高明的東西了。
首先在掌握FFT之前我們要掌握一下知識:
1.復數的計算法則.
形如(a+bi)的數叫復數,分為實部和虛部。
i是這么一個東西:i*i+1=0,虛數單位。
復數的加減法:實部虛部分別相加減。
復數的乘法:(a+bi)*(c+di)=(ac-bd)+(ad+bc)i;
除法太難打所以請戳這里。
2.復數的表達形式
感謝一位叫卜卜的熱心網友大晚上不看番教我數學。
第一種形式就是代數式:(a+bi),高中數學內容。
第二種形式也許?叫三角式:r(cosθ+isinθ)。
具體來說,將代數式里的a,b放到二維笛卡爾坐標系平面直角坐標系里,橫坐標為實部,縱坐標為虛部。把原點和(a,b)相連,記這條向量與X軸的夾角為θ,模長為r,上面那個式子就很好理解了。
那么來看看三角式下的乘法運算?
r1(cosθ1+isinθ1)*r2(cosθ2+isinθ2) = r1r2(cos(θ1+θ2)+isin(θ1+θ2))
沒錯就是這樣。
於是就有顯而易見的n次方式:
(r(cosθ+isinθ))^n=r^n(cos(nθ)+sin(nθ))
這在FFT中會用到。
還有一個公式是(cosθ+isinθ)=eiθ。推理過程要用到。
然后是多項式乘法。一個n次的多項式乘上一個m次的多項式,結果是(n+m)次的。
朴素的多項式相乘時間復雜度是O(n^2)的,不夠優秀。
而FFT則是利用了單位復根的優秀性質來解決了這么一個問題。
首先我們需要把多項式轉化成點值表示法,稱為求值。其逆過程稱為插值。
這樣有一個好處:
兩個多項式A,B分別取點(X,Ya)和(X,Yb),A×B就會取到點(X,Ya*Yb);
具體是什么原因?我認為生命需要留下一點遺憾(嘖)。
其實很好理解。
T(x)=f(x)*g(x),所以T(3)=f(3)*g(3)。
顯而易見。
所以轉化成點值表示法后,"相乘"反倒成為最簡單的了。
所以多項式相乘的基本步驟:
對A,B求值 » 點值乘法 » 插值。
若能將求值和插值的復雜度降低,就能達到我們的目的了!
FFT的核心思想:
通過恰當選取x的值,並采用分治策略使得求值和插值的復雜度降下來。
首先我們要了解的是n次單位復數根。
記為Wn...Wnn。Wnn = 1 = 1+0*i;
並且有n次單位復數根的個數為n。
算法導論告訴我們,nn個單位復數根均勻的分布在以復平面的原點為圓心的單位半徑的圓周上。
憋問我為什么
(Pic from Xlightgod)
記Wn=r(cosθ+isinθ)。
那么我們可以得知:Wnn=r^n(cos(nθ)+i*sin(nθ))=(1+0*i);
r=1,然后你稍稍推一下就知道θ=0。
設nθ=φ+2kπ,則φ=θ/n+2kπ/n;
因為θ=0,所以就是2kπ/n有值。
所以φ=2kπ/n;
sin和cos都是以2π為周期的,所以可以用φ代替θ
所以Wn=cosφ+isinφ=e(i*2kπ/n)。
接下來就可以證明一個重要的定理:
消去定理:Wakbk=(e(i*2kπ/ak))bk=e(i*2kπ*bk/ak)=e(i*2kπ*b/a)=Wab;
然后用這個定理可以證明:
折半定理:(Wnk)2=e(2*2kπ/n)=e(2kπ/(n/2))=Wn/2k;
這樣的話,一次平方下來,取值就少一倍。
接下來就是很簡單的 分治 了。
《論折半定理在信息學競賽中的簡單應用》 傅立葉
把這個多項式A(x)=Σaixi分治一下,構建新的多項式。
A[0](x)=a0+a2x+a4x2+...+an-2x(n-2)/2;
A[1](x)=a1+a3x+a5x2+...+an-1x(n-1)/2;
A(x)=A[0](x2)+x*A[1](x2);
因為這個是嚴格分治的,所以最高次項必須要是2的n次方。
(你問我n不是2的冪怎么辦?擴大一下,高位系數全為0不就完了
所以說常數大得要死。
所以我們利用快速傅立葉變換求出了離散傅立葉變換(DFT)。
好像是把求值叫DFT,把插值叫IDFT。
然后又有人證明出插值只要將Wn變成Wn-1,再將結果除以n即可。
再做一遍FFT就可以了。
Congratulations! 復雜度已經被我們降到了O((n+m)log(n+m))。
代碼實現起來竟然這么短,關鍵語句只有9行!
//uoj模板題
#include <iostream> #include <cstdio> #include <cstdlib> #include <algorithm> #include <vector> #include <cstring> #include <queue> #include <cmath> #include <complex> #define LL long long int using namespace std; const int N = 262145; const double pi = acos(-1.0); typedef complex<double> dob; int n,m; dob a[N],b[N]; int gi() { int x=0,res=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')res*=-1;ch=getchar();} while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar(); return x*res; } inline void FFT(dob *A,int len,int f) { if(len==1)return; dob wn(cos(2.0*pi/len),sin(f*2.0*pi/len)),w(1,0),t; dob A0[len>>1],A1[len>>1]; for(int i=0;i<(len>>1);++i)A0[i]=A[i<<1],A1[i]=A[i<<1|1]; FFT(A0,len>>1,f);FFT(A1,len>>1,f); for(int i=0;i<(len>>1);++i,w*=wn){ t=w*A1[i]; A[i]=A0[i]+t; A[i+(len>>1)]=A0[i]-t; } } int main() { n=gi();m=gi(); for(int i=0;i<=n;++i)a[i]=gi(); for(int i=0;i<=m;++i)b[i]=gi(); m+=n; for(n=1;n<=m;n<<=1); FFT(a,n,1);FFT(b,n,1); for(int i=0;i<=n;++i)a[i]*=b[i]; FFT(a,n,-1); for(int i=0;i<=m;++i) printf("%d ",int(a[i].real()/n+0.5)); return 0; }
然而我們早就知道遞歸有着巨大的常數,加上FFT的巨大常數(三角函數計算),導致奇慢無比。
我們來欣賞一下這個美麗的蝴蝶遞歸。
把最后一行的數化成二進制:
000,100,010,110,001,101,011,111;
然后把每一個數順序反過來:
000,001,010,011,100,101,110,111;
是個遞增的對不對?十分優美對不對?
優美個鬼啊
於是就有人喜(喪)大(心)普(病)奔(狂)推出了三層for人工合並的東西。
#include <iostream> #include <cstdio> #include <cstdlib> #include <algorithm> #include <vector> #include <cstring> #include <queue> #include <cmath> #include <complex> #define LL long long int using namespace std; const int N = 262145; const double pi = acos(-1.0); typedef complex<double> dob; int n,m,L,R[N]; dob a[N],b[N]; inline int gi() { int x=0,res=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')res*=-1;ch=getchar();} while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar(); return x*res; } inline void FFT(dob *A,int f) { for(int i=0;i<n;++i)if(i<R[i])swap(A[i],A[R[i]]); for(int i=1;i<n;i<<=1){ dob wn(cos(pi/i),sin(f*pi/i)),x,y; for(int j=0;j<n;j+=(i<<1)){ dob w(1,0); for(int k=0;k<i;++k,w*=wn){ x=A[j+k];y=w*A[j+i+k]; A[j+k]=x+y; A[j+i+k]=x-y; } } } } int main() { n=gi();m=gi(); for(int i=0;i<=n;++i)a[i]=gi(); for(int i=0;i<=m;++i)b[i]=gi(); m+=n; for(n=1;n<=m;n<<=1)++L; for(int i=0;i<n;++i)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1)); FFT(a,1);FFT(b,1); for(int i=0;i<=n;++i)a[i]*=b[i]; FFT(a,-1); for(int i=0;i<=m;++i)printf("%d ",int(a[i].real()/n+0.5)); return 0; }
第一層i是枚舉合並到了哪一層。
第二層j是枚舉合並區間。
第三層k是枚舉區間內的下標。
j*k=(n+m);i是log級的。
所以說復雜度沒變,常數降下來了。
其實常數還能降一點。
1.勿使用系統的復數庫,自己手寫結構體,只需重載加減乘即可,大概可以壓到原來時間的60%。
2.預處理所有要用到的單位復根及其冪(用三角函數式計算),這樣還可以保證精度(cogs 釋迦,不這么寫必掛),大概卷積上界達到1e14就需要預處理了。
至於更多FFT技巧,可以移步myy2016的集訓隊論文。
我們不得不承認FFT是一個優秀而鬼畜的東西。
因為有三角函數和浮點數的參與,FFT有時候會出現尷尬的爆精度現象。
這種病醫生說是救不了的。
有些題目要求答案要對一個質數取模(998244353),我們知道取模是數論內容。
那么有沒有什么東西可以替代單位復根呢?
當然有!原根!
設原根為g。
Wnn≡gP-1≡1(mod P);
所以可以把g(P-1)/n看成Wn的等價。
好的NTT學完了。
所以說這種質數必須是NTT質數(費馬質數),即(P-1)有超過序列長度的2的正整數冪因子的質數,如998244353,1004535809,469762049等。
不是這種質數怎么辦?找幾個找乘積大於p^2*n的費馬質數做,用中國剩余定理合並就好了。
#include <iostream> #include <cstdio> #include <cstdlib> #include <algorithm> #include <vector> #include <cstring> #include <queue> #include <complex> #include <stack> #define LL long long int #define ls (x << 1) #define rs (x << 1 | 1) #define MID int mid=(l+r)>>1 using namespace std; const int N = 300010; const int Mod = 998244353; int n,m,L,R[N],g[N],a[N],b[N]; int gi() { int x=0,res=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')res*=-1;ch=getchar();} while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar(); return x*res; } inline int QPow(int d,int z) { int ans=1; for(;z;z>>=1,d=1ll*d*d%Mod) if(z&1)ans=1ll*ans*d%Mod; return ans; } inline void NTT(int *A,int f) { for(int i=0;i<n;++i)if(i<R[i])swap(A[i],A[R[i]]); for(int i=1;i<n;i<<=1){ int gn=QPow(3,(Mod-1)/(i<<1)),x,y; for(int j=0;j<n;j+=(i<<1)){ int g=1; for(int k=0;k<i;++k,g=1ll*g*gn%Mod){ x=A[j+k];y=1ll*g*A[i+j+k]%Mod; A[j+k]=(x+y)%Mod;A[i+j+k]=(x-y+Mod)%Mod; } } } if(f==1)return;reverse(A+1,A+n); int y=QPow(n,Mod-2); for(int i=0;i<n;++i)A[i]=1ll*A[i]*y%Mod; } int main() { n=gi();m=gi(); for(int i=0;i<=n;++i)a[i]=gi(); for(int i=0;i<=m;++i)b[i]=gi(); m+=n;for(n=1;n<=m;n<<=1)++L; for(int i=0;i<n;++i)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1)); NTT(a,1);NTT(b,1); for(int i=0;i<n;++i)a[i]=1ll*a[i]*b[i]%Mod; NTT(a,-1); for(int i=0;i<=m;++i)printf("%d ",a[i]); printf("\n"); return 0; }