建議同學們先自學一下“復數(虛數)”的性質、運算等知識,不然看這篇文章有很大概率看不懂。
前言
作為一個典型的蒟蒻,別人的博客都看不懂,只好自己寫一篇了。
膜拜機房大佬 HY
一. FFT是蛤??
FFT (快速傅里葉變換) 的作用是在 O(nlogn) 時間算出多項式乘法的一個特別神奇的算法。
大家平時碼的多項式乘法都是 O(n^2) 的吧
1 #include<iostream> 2 #include<cstdio> 3 using namespace std; 4 5 int n,m,a[10005],b[10005],c[20005]; 6 7 int main(){ 8 scanf("%d%d",&n,&m); 9 for(int i=0;i<n;i++)scanf("%d",a+i); 10 for(int i=0;i<m;i++)scanf("%d",b+i); 11 for(int i=0;i<n;i++) 12 for(int j=0;j<m;j++)c[i+j]+=a[i]*b[j]; 13 for(int i=0;i<n+m-1;i++) 14 printf("%d ",c[i]); 15 }
但這個算法並不能解決什么問題。
n<=100000 恭喜你,你成功TLE了,這時就要用到FFT了!!(是不是很激動?)
二. 算法思想
相信大家十分想知道這神奇的算法是怎么工作的。我們平時表達多項式的方法是系數表示法,而我們要把這個多項式換成另一個神奇的表達方法——點值表示法。這種神奇的表示法可以在 O(n) 的時間內算出多項式乘法,可是很遺憾,要想讓這兩種表示法互相轉化任是需要 O(n^2) 的時間,而FFT的核心就是在 O(nlogn) 的時間內實現轉換。
三. 系數表示法和點值表示法
系數表示法
就是用一個多項式的各個項的系數表示這個多項式,也就是我們平時所用的表示法。例如,我們可以這樣表示:
f(x)=a0+a1x1+a2x2+..+anxn⇔f(x)={a0,a1,a2,..,an}
這就像是我們用數組存一個多項式一樣
點值表示法
就是把這個多項式理解成一個函數,用這個函數上的若干個點的坐標來描述這個多項式。(兩點確定一條直線,三點確定一條拋物線…同理n+1個點確定一個n次函數)
因此表示成這樣:(注意:x[0]->x[n]是n+1個點)
f(x)=a0+a1x+a2x2+..+anxn⇔f(x)={(x0,y0),(x1,y1),(x2,y2),..,(xn,yn)}
為什么n+1個確定的點能確定一個唯一的多項式呢?你可以嘗試着把這n+1個點的值分別代入多項式中:
如圖,我們把相應的 x 與 y 的值代入,就能的到n+1個方程,也就能解出n+1個位置數,即數組 a,這樣也就確定了一個多項式。
四. 點值表達式的乘法
現在,考慮這樣一個問題,如果我有兩個用點值表示的多項式,如何表示它們兩個多項式的乘積呢?
我們令這兩個點值表達式的 x 值相等,則會有一組唯一確定的 y 值。
結果F(x)=f(x)×g(x),那么就有F(x0)=f(x0)∗g(x0)(x0x0為任意數)。
思考一下,很容易得出,如果 x 的取值相同,結果多項式的值就是兩個因式的值的乘積
也就是說,如果我把兩個函數的點值表示法中的 x 值相同的點的 y 值乘在一起就是它們的乘積(新函數)的點值表示。
這就可以O(n)計算多項式乘法。
五. 復數
我們把形如a+bi(a,b均為實數)的數稱為復數,其中a稱為實部,b稱為虛部,i稱為虛數單位。當虛部等於零時,這個復數可以視為實數;當z的虛部不等於零時,實部等於零時,常稱z為純虛數。 ————百度百科
這是復數的定義,不過為什么要用復數呢??除非作者腦子有問題,不然肯定不會講無關的東西
雖然點值表達式的乘法是O(n)的,可我們求的是系數表達式,而系數表達式與點值表達式的轉換卻是O(n^2)
復數的引用可以對這里進行優化。優化的方法我們下面再說。
我們給定一個坐標系,橫軸表示 a,縱軸表示 b,這樣所有的復數都可以在這里表示出來,這便是復數的幾何意義。更多關於復數的內容請自行了解在這就不闡述了
然后,思考一個簡單的問題:兩個復數的乘法有沒有某種特定的幾何意義?(只是一個數學性質,在此不進一步深究,可用三角函數證明。)
如圖可得,復數的乘法,長度相乘,極角相加。
六. 單位復根
現在,回到我們剛才講到的“點值表示法”的問題,要想轉化,也就是要解一個n+1元的方程組。
當我們計算x0,x02,...,x0n 時會浪費大量的時間。這個數學運算看似是沒有辦法加速的,而實際上我們可以找到一種神奇的“x值”,帶進去之后不用反復地去做無用n次方操作,比如 1 與 -1,可以加速。
但是我們要至少帶進去n+1個不同的數才能進行系數表示。這時就要用到復數了!
我們需要的是滿足“ωk=1”的數(k為整數)
看上圖中的紅圈,紅圈上的每一個點距原點的距離都是1個單位長度,所以說如果說對這些點做k次方運算,它們始終不會脫離這個紅圈。
因為它們在相乘的時候r始終=1,只是θ的大小在發生改變。而這些點中有無數個點經過k次方之后可以回到“1”。
因此,我們可以把這樣的一組神奇的x帶入函數求值。像這種能夠通過k次方運算回到“1”的數,我們叫它“復根”用“ω”表示。
你會發現:其實k次負根就相當於是給圖中的圓周平均分成k個弧,弧與弧之間的端點就是“復根”,我們只需要知道ωn1,就能求出ωnk。所以我們稱“ωn1”為“單位復根”。
其實,我們用“ωk”表示單位復根,ωk1表示的是“單位復根”的“1次方”也就是它本身,其他的就叫做 k 次單位復根的 n 次方。
七. FFT 之 DFT
前面的復數都是數學的內容,所以講的比較簡略,不過也終於到正題 FFT 了!
DFT 是 FFT 中將系數表達式轉變為點值表達式的過程。
我們把多項式的系數表達式,換成 x 值 為 ωnk 的點值表達式。
f(x)=a0 + a1*x + a2*x2 + a3*x3 + ...... + an-1*xn-1
f(x)={(ωn0,y0),(ωn1,y1),(ωn2,y2), ...... ,(ωnn-1,yn-1)}
然后我們可以將 x 值(ωnk)省略,只儲存 y0 , y1 , ......, yn
可我們將 ωnk 帶入 x 后又能怎么優化呢?我們可以嘗試一下分治思想。
將 ωnk 和 ωnk+n/2 代入,就可以發現一個神奇的現象
F( ωnk )=G(ωn2k)+ωnk * H(ωn2k)
=G(ωn2k) + ωnk * H(ωn2k)
=G(ωn/2k) + ωnk * H(ωn/2k)
F(ωnk+n/2)=G(ωn2k+n) + ωnk+n/2 * H(ωn2k+n)
=G(ωn2k * ωnn) - ωnk * H(ωn2k * ωnn)
=G(ωn2k) - ωnk * H(ωn2k)
=G(ωn/2k) - ωnk * H(ωn/2k)
沒想到得出來的式子竟然這么相近,也就是說,我們把其中一個值帶入,就可以的到另一個,我們就可以把時間縮小一半了。
接下來就可以遞歸求解了!!!
1 const double PI=acos(-1); //圓周率 π
2 typedef complex<double> cmplx;//我比較懶,就用了STL自帶的復數類
3 void DFT(int len,cmplx a[]){ 4 if(len==1)return; //只有一個常數項
5 cmplx a1[len>>1],a2[len>>1]; 6 for(int i=0;i<=len;i+=2) //根據下標的奇偶性分類
7 a1[i>>1]=a[i],a2[i>>1]=a[i+1]; 8 FFT(len>>1,a1),FFT(len>>1,a2); 9 cmplx W=exp(cmplx(0,PI/len));//求為單位根ω
10 cmplx w=cmplx(1,0); //w表示0~n-1次冪,初始為0次冪 1
11 for(int i=0;i<(len>>1);i++,w=w*W){ 12 a[i]=a1[i]+w*a2[i]; //上文我們推導的性質
13 a[i+(len>>1)]=a1[i]-w*a2[i]; 14 //利用單位根的性質,O(1)得到另一部分
15 } 16 }
是不是很友好?可是遞歸實現的缺點也很顯著,空間都消耗巨大,所以我們就要模擬遞歸了。
遞歸的時候,我們是將多項式奇偶拆開,如圖
這看似拆出來沒什么規律,但我們試着把數換為二進制,又會發生什么呢?
拆完后的多項式竟然是原來的二進制翻轉!!!我們就可以這樣通過倍增來模擬遞歸了!!!
1 const double PI=acos(-1); 2 typedef complex<double> cmplx; 3
4 void get_rev(){ //求二進制反轉
5 while(bit<=n)bit<<=1; //bit為最大二進制位長度的值
6 for(int i=0;i<bit;i++) 7 rev[i]=(rev[i>>1]>>1)|(i&1)*(bit>>1); 8 } 9
10 void DFT(cmplx a[]){ 11 for(int i=0;i<bit;i++) 12 if(i<rev[i])swap(a[i],a[rev[i]]); 13 //根據rev數組進行二進制反轉
14 for(int i=1;i<bit;i<<=1){ //倍增模擬遞歸
15 cmplx W=exp(cmplx(0,PI/i)); 16 for(int j=0;j<bit;j+=i<<1){ //一組一組處理
17 cmplx w(1,0); //同遞歸版代碼
18 for(int k=j;k<j+i;k++,w*=W){ //同遞歸版代碼
19 cmplx x=a[k]; 20 cmplx y=w*a[k+i]; 21 a[k]=x+y,a[k+i]=x-y; 22 } 23 } 24 } 25 }
雖然丑了,不過優秀了許多。
八. FFT 之 IDFT
我們將系數表示法轉為點值表示法,總要把它變回來,而變回來的過程就是 IDFT 了。
IDFT似乎要矩陣的知識證明(而我不會,尷不尷尬),於是乎,我就只亮一波代碼好了!
1 void IDFT(cmplx a[]){ 2 for(int i=0;i<bit;i++) 3 if(i<rev[i])swap(a[i],a[rev[i]]); 4 for(int i=1;i<bit;i<<=1){ 5 cmplx W=exp(cmplx(0,-PI/i)); 6 for(int j=0;j<bit;j+=i<<1){ 7 cmplx w(1,0); 8 for(int k=j;k<j+i;k++,w*=W){ 9 cmplx x=a[k]; 10 cmplx y=w*a[k+i]; 11 a[k]=x+y,a[k+i]=x-y; 12 } 13 } 14 } 15 for(int i=0;i<bit;i++)a[i]/=bit; 16 }
你會發現,這只是幾個符號的差別(所以你也不是很必要知道原理了吧,好奇的同學只能自己探索了)
其實我們可以吧 DFT 和 IDFT 合並成一個函數
1 void FFT(cmplx a[],int dft){ //1是DFT,-1是IDFT
2 for(int i=0;i<bit;i++) 3 if(i<rev[i])swap(a[i],a[rev[i]]); 4 for(int i=1;i<bit;i<<=1){ 5 cmplx W=exp(cmplx(0,dft*PI/i)); 6 for(int j=0;j<bit;j+=i<<1){ 7 cmplx w(1,0); 8 for(int k=j;k<j+i;k++,w*=W){ 9 cmplx x=a[k]; 10 cmplx y=w*a[k+i]; 11 a[k]=x+y,a[k+i]=x-y; 12 } 13 } 14 } 15 if(dft==-1)for(int i=0;i<bit;i++)a[i]/=bit; 16 }
九. 總結
法法塔到這也就結束了,沒什么好說,再亮一波FFT代碼
洛谷 P3803 【模板】多項式乘法(FFT)
1 #include<iostream>
2 #include<cstdio>
3 #include<complex>
4 using namespace std; 5
6 const double PI=acos(-1); 7 typedef complex<double> cmplx; 8 cmplx a[2500005],b[2500005]; 9 int m,n,x,bit=2,rev[2500005]; 10 int output[2500005]; 11
12 void get_rev(){ 13 for(int i=0;i<bit;i++) 14 rev[i]=(rev[i>>1]>>1)|(bit>>1)*(i&1); 15 } 16
17 void FFT(cmplx *a,int dft){ 18 for(int i=0;i<bit;i++) 19 if(i<rev[i])swap(a[i],a[rev[i]]); 20 for(int i=1;i<bit;i<<=1){ 21 cmplx W=exp(cmplx(0,dft*PI/i)); 22 for(int j=0;j<bit;j+=i<<1){ 23 cmplx w(1,0); 24 for(int k=j;k<j+i;k++,w=w*W){ 25 cmplx x=a[k]; 26 cmplx y=w*a[k+i]; 27 a[k]=x+y,a[k+i]=x-y; 28 } 29 } 30 } 31 if(dft==-1) 32 for(int i=0;i<bit;i++)a[i]/=bit; 33 } 34
35 int main(){ 36 scanf("%d%d",&n,&m); 37 while(bit<=n+m)bit<<=1; 38 for(int i=0;i<=n;i++) 39 scanf("%d",&x),a[i]=x; 40 for(int i=0;i<=m;i++) 41 scanf("%d",&x),b[i]=x; 42 get_rev(); 43 FFT(a,1),FFT(b,1); 44 for(int i=0;i<bit;i++)a[i]*=b[i]; 45 FFT(a,-1); 46 for(int i=0;i<bit;i++) 47 output[i]+=a[i].real()+0.5; 48 for(int i=0;i<=n+m;i++) 49 printf("%d ",output[i]); 50 }