最近認真研究了一下算法導論里面的多項式乘法的快速計算問題,主要是用到了FFT,自己也實現了一下,總結如下。
1.多項式乘法
兩個多項式相乘即為多項式乘法,例如:3*x^7+4*x^5+1*x^2+5與8*x^6+7*x^4+6*x^3+9兩個式子相乘,會得到一個最高次數項為13的多項式。一般來說,普通的計算方法是:把A多項式中的每一項與B中多項式中的每一項相乘,得到n個多項式,再把每個多項式相加到一起,得到最終的結果,不妨假設A,B的最高次項都為n-1,長度都為n,那么計算最終的結果需要o(n^2)時間復雜度。而使用快速傅里葉變換(FFT),則可以將時間復雜度降低到o(nlog n)。這是因為,對一個復數序列做正/反快速傅里葉變換的時間復雜度都是o(nlog n),而變換后的序列逐項相乘即為原序列做多項式乘法的結果(多項式乘法相當於卷積)。所以,FFT可以降低多項式相乘運算的時間復雜度,具體的解釋和證明在《算法導論》或者其他任何一本相關的算法書中都有詳細描述,在此不再贅述。另外,需要注意的是,一些其他的運算也可以轉化成多項式乘法,進而利用FFT來加快運算。例如:1.數字乘法運算,和多項式乘法類似,A*B的操作就是用A每一位上的數字乘以B每一位上的數字。尤其是在大數乘法中,FFT可以大幅度加快運算。2.給定A到B的不同長度路徑詳細數據,B到C的不同路徑長度詳細數據,求A到C不同長度路徑的數量。可以把A到B和B到C不同長度的路徑看成不同次數的項,例如:A到B有3條長度為4,2條長度為5的路徑,B到C有1條長度為2,4條長度為3的路徑,那么A到C不同長度路徑的數量等於(3*x^4+2*x^5)*(4*x^3+1*x^2)得到的各項的系數,轉化成多項式相乘問題之后,就可以利用FFT來加快運算速度了。
2.FFT
大多數人應該是只需要會用FFT即可,但是這個算法比較基礎,因此我自己編程實現了一下,總的代碼只有150行左右,其實不算長,當然,對輸入序列長度不是2的整數次冪這種情況我沒有相應的預處理,算是偷懶了。其實只要對長度取一下對數即可,例如:輸入長度如果是37,首先把37*2,拓展成74(這個是FFT必須的),然后對74取log2上取整即可,得到27=128,因此在74后面再添加54個0。
另外,需要注意的一點是,reverse函數在FFT和IFFT中是必須的,不過鑒於多項式乘法需要成對進行FFT和IFFT,所以在做多項式乘法的時候,reverse應該是可以省略的(當然,這個函數的耗時很小)。w的值應該提前計算出來,這樣在FFT和IFFT中蝶形計算每一項的時候,就不用重復計算w了,可以節省很多時間。
具體FFT的原理和解釋,可以查維基、信息論、數字信號處理、隨機過程等任一領域的教科書。
3.代碼
程序主要包括FFT,IFFT函數,以及一些復數運算
3.1復數的定義和相關運算定義
//復數 struct Complex{ double real; double image; }; Complex a1[MAX_SIZE],a2[MAX_SIZE],result[MAX_SIZE],w[MAX_SIZE]; //復數相乘計算 Complex operator*(Complex a,Complex b){ Complex r; r.real=a.real*b.real-a.image*b.image; r.image=a.real*b.image+a.image*b.real; return r; } //復數相加計算 Complex operator+(Complex a,Complex b){ Complex r; r.real=a.real+b.real; r.image=a.image+b.image; return r; } //復數相減計算 Complex operator-(Complex a,Complex b){ Complex r; r.real=a.real-b.real; r.image=a.image-b.image; return r; } //復數除法計算 Complex operator/(Complex a,double b){ Complex r; r.real=a.real/b; r.image=a.image/b; return r; } //復數虛部反計算 Complex operator~(Complex a){ Complex r; r.real=a.real; r.image=0-a.image; return r; }
3.2FFT和IFFT函數及相關函數
其實FFT和IFFT的原理一樣,只是IFFT多了一個除法步驟,也可以把兩個合並成一個函數。Reverse用於重新排列輸入數組的元素下標,例如輸入數組長度為8,則0,1,2,3,4,5,6,7下標的元素經過重新排列后變為0,4,2,6,1,5,3,7下標的元素。Compute_W用於預先計算FFT中需要的w值。
//重新排列方法2,效率較高 void Reverse(int* id,int size,int m){ for(int i=0;i<size;i++){ for(int j=0;j<(m+1)/2;j++){ int v1=(1<<(j)&i)<<(m-2*j-1); int v2=(1<<(m-j-1)&i)>>(m-2*j-1); id[i]|=(v1|v2); } } };
//重新排列方法1,該方法是用pow函數效率比較低 void Reverse(int* id,int size,int m){ for(int i=0;i<size;i++){ for(int j=0;j<m;j++){ int exp=(i>>j)&1; id[i]+=exp*(int)pow((double)2,(double)(m-j-1)); } } };
//計算並存儲需要乘的w值 void Compute_W(Complex w[],int size){ for(int i=0;i<size/2;i++){ w[i].real=cos(2*PI*i/size); w[i].image=sin(2*PI*i/size); w[i+size/2].real=0-w[i].real; w[i+size/2].image=0-w[i].image; } }; //快速傅里葉 void FFT(Complex in[],int size){ int* id=new int[size]; memset(id,0,sizeof(int)*size); int m=log((double)size)/log((double)2); Reverse(id,size,m); //將輸入重新排列,符合輸出 Complex *resort= new Complex[size]; memset(resort,0,sizeof(Complex)*size); int i,j,k,s; for(i=0;i<size;i++) resort[i]=in[id[i]]; for(i=1;i<=m;i++){ s=(int)pow((double)2,(double)i); for(j=0;j<size/s;j++){ for(k=j*s;k<j*s+s/2;k++){ Complex k1= resort[k]+w[size/s*(k-j*s)]*resort[k+s/2]; resort[k+s/2]=resort[k]-w[size/s*(k-j*s)]*resort[k+s/2]; resort[k]=k1; } } } for(i=0;i<size;i++) in[i]=resort[i]; delete[] id; delete[] resort; }; //快速逆傅里葉 void IFFT(Complex in[],int size){ int* id=new int[size]; memset(id,0,sizeof(int)*size); int m=log((double)size)/log((double)2); Reverse(id,size,m); //將輸入重新排列,符合輸出 Complex *resort= new Complex[size]; memset(resort,0,sizeof(Complex)*size); int i,j,k,s; for(i=0;i<size;i++) resort[i]=in[id[i]]; for(i=1;i<=m;i++){ s=(int)pow((double)2,(double)i); for(j=0;j<size/s;j++){ for(k=j*s;k<j*s+s/2;k++){ Complex k1=(resort[k]+(~w[size/s*(k-j*s)])*resort[k+s/2]); resort[k+s/2]=(resort[k]-(~w[size/s*(k-j*s)])*resort[k+s/2]); resort[k]=k1; } } } for(i=0;i<size;i++) in[i]=resort[i]/size; delete[] id; delete[] resort; };
3.3主函數
輸入兩個多項式的系數(長度必須都是2的整數次冪),輸出兩個多項式相乘的結果
int main(){ //輸入兩個多項式數列 int size,size1,size2,i; memset(a1,0,sizeof(a1)); memset(a2,0,sizeof(a2)); memset(w,0,sizeof(w)); memset(result,0,sizeof(result)); scanf("%d%d",&size1,&size2); for(i=0;i<size1;i++) scanf("%lf",&a1[i].real); for(i=0;i<size2;i++) scanf("%lf",&a2[i].real); size=size1>size2?size1*2:size2*2; Compute_W(w,size); FFT(a1,size); FFT(a2,size); for(i=0;i<size;i++) result[i]=a1[i]*a2[i]; IFFT(result,size); for(i=0;i<size1+size2-1;i++) printf("%.2lf ",result[i].real); printf("\n"); return 0; }
下面是完整的代碼