淺談FFT(快速傅里葉變換)


本文主要簡單寫寫自己在算法競賽中學習FFT的經歷以及一些自己的理解和想法。

FFT的介紹以及入門就不贅述了,網上有許多相關的資料,入門的話推薦這篇博客:FFT(最詳細最通俗的入門手冊),里面介紹得很詳細。

為什么要學習FFT呢?因為FFT能將多項式乘法的時間復雜度由朴素的$O(n^2)$降到$O(nlogn)$,這相當於能將任意形如$f[k]=\sum\limits _{i+j=k}f[i]\cdot f[j]$的轉移方程的計算在$O(nlogn)$的時間內完成。因此對於想要進階dp的同學來說,FFT是必須掌握的技能之一。(雖然在賽場上可能沒什么用武之地)

我學習FFT的過程也是比較曲折的,從接觸到真正理解它的原理前前后后經歷了半年的時間。(實際上我從去年接觸了FFT之后就一直把它當做一個黑盒算法來用,研究的事就扔到一邊了,只是偶爾簡單推算過幾次公式,直到這個月初才開始深入學習它的原理)

由於本人才疏學淺,所以自己的敘述若存在一些錯誤或者不足之處,敬請讀者指正。

 首先FFT的作用是什么?可以將多項式的系數表達式轉化成點值表達式(或者反過來,方法都是一樣的)。FFT(a,n)的作用是將多項式a(系數表達式)從$w_{n}^{0}$到$w_{n}^{n-1}$的所有根對應的取值求出來。也就是說,設$f(x)=\sum\limits_{i=0}^{n-1}a[i]\cdot x^i$,經過FFT變換后,a[i]變成了$f(w_{n}^{i})$。

這個利用單位根來表示的點值表達式的一個好處是如果已知FFT(a0,n/2)以及FFT(a1,n/2)(a0為a的偶數次項所構成的多項式,a1為a的奇數次項所構成的多項式),則根據性質$\left\{\begin{matrix}\begin{aligned}&a[i]=a_0[i]+w_{n}^{i}\cdot a_1[i]\\&a[i+\frac{n}{2}]=a_0[i]-w_{n}^{i}\cdot a_1[i]\end{aligned} \end{matrix}\right.$可以在$O(n)$的時間內算出數組a的值。

 為什么要用單位根呢?因為對於任意的數組長度n,在FFT的過程中使用單位根都只需要計算n個不同變量的值,與數組長度是線性相關的,而且一定能保證取到n個不同的值。而假如取2,3,4這樣的數的話,在對任意子數組進行FFT時仍需計算n個不同變量的值,這樣的話總的復雜度仍為$O(n^2)$,沒有絲毫降低。而假如取-1,1這樣的數,雖然只需要計算常數個變量的值了,但無論如何只能取到一兩個變量的值,也就是只能確定兩點,無法確定一個具有n個維度的多項式。

接下來就是代碼實現了。

首先我們做一下預處理:

1 typedef double db;
2 const db pi=acos(-1);

把double定義成db的作用,一是可以簡化代碼,二是需要調整精度的時候可以很方便地替換成其他變量類型,比如long double。

FFT的運算要用到復數,這就意味着我們必須找到一個能夠代表復數的變量類型。圖方便的話,C++庫中內置的complex類就夠用了。不過還是推薦自己寫一個結構體,比C++自帶的要快很多,而且也很好寫。

由於復數是一個二元組,和二維平面上的點非常類似,因此可以直接套用二維幾何中的點的結構體代碼。加減數乘等操作都完全一樣,只是多了個乘法。但這並不影響它的幾何意義,因為在計算幾何中兩向量乘法我一般喜歡用dot(點積)和cross(叉積)兩個函數來表示。此外,乘法運算符也可以表示坐標的旋轉。

復數(點)的結構體代碼如下:

1 struct P {
2     db x,y;
3     P operator+(const P& b) {return {x+b.x,y+b.y};}
4     P operator-(const P& b) {return {x-b.x,y-b.y};}
5     P operator*(const P& b) {return {x*b.x-y*b.y,x*b.y+y*b.x};}
6     P operator/(db b) {return {x/b,y/b};}
7 }

接下來就是FFT的實現了。有了FFT的基本概念和點的表示方法之后,我們不難寫出這樣的代碼:(f為1代表正變換(取值),f為1代表逆變換(插值))

 1 void FFT(P* a,int n,int f) {
 2     if(n==1)return;
 3     static P b[N];
 4     for(int i=0; i<n; i+=2)b[i/2]=a[i],b[(i+n)/2]=a[i+1];
 5     for(int i=0; i<n; ++i)a[i]=b[i];
 6     FFT(a,n/2,f),FFT(a+n/2,n/2,f);
 7     P wn= {cos(2*pi/n),f*sin(2*pi/n)},w= {1,0};
 8     for(int i=0; i<n/2; ++i,w=w*wn) {
 9         P x=a[i],y=w*a[i+n/2];
10         a[i]=x+y,a[i+n/2]=x-y;
11     }
12 }

可以看出,這個代碼是遞歸式的,其基本思想是將數組a分成兩部分,偶數次項放在左半邊,奇數次項放在右半邊,然后對左右兩部分分別遞歸做同樣的處理,最后把兩部分的答案合並,合並后a[0]-a[n-1]中的值分別為$f(w_{n}^{0})$-$f(w_{n}^{n-1})$的值。

但是遞歸在速度方面畢竟是硬傷,因此我們希望能將遞歸換成迭代的形式,這樣速度會快很多。

 通過觀察,我們不難發現,FFT的第一步總是將a[i]與a[i+n/2]合並,每個多項式相鄰兩項在數組中的距離為n(即只有一項),而最后一步總是將a[i]與a[i+1]合並,每個多項式相鄰兩項的距離為2,中間每合並一輪,距離減半。經過一番觀察和推理之后,我們可以得到如下改進后的代碼:

 1 void FFT(P* a,int n,int f) {
 2     static P b[N];
 3     P *A=a,*B=b;
 4     for(int k=n; k>=2; k>>=1,swap(A,B))
 5         for(int i=0; i<k>>1; ++i) {
 6             P wn= {cos(pi*k/n),f*sin(pi*k/n)},w= {1,0};
 7             for(int j=i; j<n; j+=k,w=w*wn) {
 8                 P x=A[j],y=w*A[j+(k>>1)];
 9                 B[((j-i)>>1)+i]=x+y,B[((j-i)>>1)+(n>>1)+i]=x-y;
10             }
11         }
12     if(A!=a)for(int i=0; i<n; ++i)a[i]=A[i];
13     if(!~f)for(int i=0; i<n; ++i)a[i]=a[i]/n;
14 }

這樣我們就成功地去掉了遞歸,換成了迭代實現的版本。中間使用了兩個指針A,B,是用乒乓效應減少數組的復制次數,有點類似倍增求后綴數組的方法。

但是這樣雖然去掉遞歸了,但仍需要$O(n)$的輔助空間,而且如果迭代次數為奇數次的話,最后還需要把變換后的數組復制回原數組,不太美觀。可以把輔助空間去掉,直接在原數組上進行合並嗎?

對於上述代碼,假設我們把x+y,x-y的值分別直接賦給a[((j-i)>>1)+i]和a[((j-i)>>1)+(n>>1)+i],那么原來這兩個位置上的信息就消失了,而這些信息在后面的合並中可能還需要用到,賦給其他位置也是同理。因此不能直接在原數組上進行賦值。這意味着,如果想直接在原數組上進行合並,合並后的兩個值和合並前的兩個值所存放的位置必須相同。例如,假如我們要合並下標分別為{0,4,8,12}和{2,6,10,14}的兩個數組,那么a[0]+w*a[2]的值必須放在a[0]或者a[2]的位置,a[0]-w*a[2]的值則必須放在另一個對應的位置。這樣一來,順序會變得很亂(自己試一試就知道了),因此若想在合並后不改變原數組中各項的位置,就必須在合並前把原數組“打亂”(當然不是隨便打亂,是對原數組進行一定規則的變換)。

如何“打亂”呢?我們可以把合並的過程倒過來觀察一下,這里借用一下網絡上的一張圖:

如圖所示,我們把“合並”的過程看成是倒過來“拆分”的過程,這是其中一種拆分的方法,可以發現,這種拆分的方法能保證“任意兩個位置上的數進行合並后的結果仍保存在它們各自的位置上,且合並后原數組的順序不變”,這樣就可以直接在原數組上進行合並了。

這種拆分方法有什么規律呢?同樣也可以發現,第一次拆分后,偶數次項都被分到了左邊,而奇數次項都分到了右邊。第二次拆分后,把每個項的次數都除以二(向下取整),得到的數為偶數的繼續被分到左邊,為奇數則被分到右邊,同理第三次拆分后要把每個項的次數除以4,第四次除以8......以此類推。從而我們可以總結出規律:設$n=2^t$,$rev(i)$為原數組的位置i拆分后對應的下標,$b(i,k)$為數字i的二進制第k位上的數(k∈{0,1}),利用按位累加的方法可以得到:

$b(rev(i),t-1-k)=\left\{\begin{matrix}\begin{aligned}0,b(i,k)=0\\ 1,b(i,k)=1\end{aligned}\end{matrix}\right.$

這相當於,每個數的拆分后的二進制第k位和原來的第t-1-k位是相同的,相當於把這個數的前t位二進制位進行了反轉。

如何利用數組拆分后,對應的下標二進制反轉的特性來對數組重排呢?一種比較普遍的方法是利用遞推的方法求出原數組反轉后的rev數組(方法不再敘述,網絡上一搜便知),再從前往后掃一遍原數組,遇到rev數組中對應的元素比它小的情況,就交換一次。這種方法的時間復雜度是$O(n)$的,但仍需要$O(n)$的輔助空間,而且對於不同的n要重新求一遍rev數組,比較麻煩。直到我找到了這樣的一段代碼:

1 void change(P* a,int n) {
2     for(int i=1,j=n>>1,k; i<n-1; ++i) {
3         if(i<j)swap(a[i],a[j]);
4         k=n>>1;
5         while(j>=k)j-=k,k>>=1;
6         j+=k;
7     }
8 }

這段代碼打眼一看可能會有點懵逼,這是在干嘛?其實自己模擬一下便知,這是在對一個數組“暴力”進行反轉,方法是模擬“倒過來加”的過程,把左起第一個0變成1,把前面的1都變成0,這樣倒過來看就好像是整個數加了1,從頭到尾掃一遍就行了。甚至可以改寫成位運算的形式:

1 void change(P* a,int n) {
2     for(int i=1,j=n>>1,k; i<n-1; ++i,j^=k) {
3         if(i<j)swap(a[i],a[j]);
4         for(k=n>>1; j&k; j^=k,k>>=1);
5     }
6 }

這樣一來,FFT的空間消耗就徹底變成$O(1)$了。但是還有一個問題,就是這個函數的時間復雜度是多少呢?

可以看出,這個函數的時間復雜度主要取決於k的移動次數。不考慮邊界情況的話,假如j的第一個0在第n-1位,那么k只需要移動一次(賦值成n/2),這樣的情況一共有n/2種;假如第一個0在第n-2位,那么k需要移動兩次,這樣的情況一共有n/4種...以此類推。最壞的情況是第一個0在第0位,此時需要移動logn次,但這只有一種情況。

因此,假設$n=2^t$,則函數中的k總共需要移動$\sum\limits_{i=1}^ti\cdot 2^{t-i}$次。

這個式子怎么算呢?

我們考慮等比級數$\sum\limits_{i=1}^tx^{t-i+1}=\frac{x(1-x^t)}{1-x}$

等式兩邊求導得$\sum\limits_{i=1}^t(t-i+1)x^{t-i}=\frac{1-(t+1)x^t+tx^{t+1}}{(1-x)^2}$

又有$\sum\limits_{i=1}^t(t-i+1)x^{t-i}=(t+1)\sum\limits_{i=1}^tx^{t-i}-\sum\limits_{i=1}^tix^{t-i}$

即$\sum\limits_{i=1}^tix^{t-i}=(t+1)\sum\limits_{i=1}^tx^{t-i}-\sum\limits_{i=1}^t(t-i+1)x^{t-i}=\frac{(t+1)(1-x^t)}{1-x}-\frac{1-(t+1)x^t+tx^{t+1}}{(1-x)^2}$

將x=2代入得$\sum\limits_{i=1}^ti\cdot 2^{t-i}=\frac{(t+1)(1-2^t)}{1-2}-\frac{1-(t+1)2^t+t2^{t+1}}{(1-2)^2}$

化簡得$\sum\limits_{i=1}^ti\cdot 2^{t-i}=2^{t+1}-t-2=2n-logn-2=O(n)$

對,你沒有看錯,空間復雜度降到了$O(1)$,而時間復雜度仍為$O(n)$,刺不刺激?

經過多次優化,可以最終得到了如下的FFT代碼:

 1 void FFT(P* a,int n,int f) {
 2     for(int i=1,j=n>>1,k; i<n-1; ++i,j^=k) {
 3         if(i<j)swap(a[i],a[j]);
 4         for(k=n>>1; j&k; j^=k,k>>=1);
 5     }
 6     for(int k=1; k<n; k<<=1) {
 7         P wn= {cos(pi/k),f*sin(pi/k)};
 8         for(int i=0; i<n; i+=k<<1) {
 9             P w= {1,0};
10             for(int j=i; j<i+k; ++j,w=w*wn) {
11                 P x=a[j],y=w*a[j+k];
12                 a[j]=x+y,a[j+k]=x-y;
13             }
14         }
15     }
16     if(!~f)for(int i=0; i<n; ++i)a[i]=a[i]/n;
17 }

非遞歸實現,$O(nlogn)$的時間復雜度,O(1)的空間復雜度,既保證了效率又簡潔了代碼,豈不美哉?

有了FFT的代碼,就可以實現多項式乘法了。用FFT實現多項式乘法的一般步驟是將被乘的兩個多項式分別用FFT轉化成點值表達式,然后對應位相乘,最后再用FFT逆變換轉化回來就行了。

值得注意的是,對被乘的兩個多項式進行FFT時,數組長度至少應大於兩個多項式的最高次數之和,否則會出現莫名其妙的錯誤。又因為數組長度必須是2的t次方的形式,保險起見最好開到多項式相乘后的最高次數的兩倍或以上。

最后推薦幾道FFT的練習題:

HDU - 4609 3-idiots

UVA - 12298 Super Poker II

Gym - 101002E K-Inversions

Gym - 101667H Rock Paper Scissors

HDU - 1402 A * B Problem Plus

Gym - 101234D Forest Game

順便附上UVA - 12298的完整代碼:

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 typedef long double db;
 5 const int N=2e5+10;
 6 const db pi=acos(-1);
 7 struct P {
 8     db x,y;
 9     P operator+(const P& b) {return {x+b.x,y+b.y};}
10     P operator-(const P& b) {return {x-b.x,y-b.y};}
11     P operator*(const P& b) {return {x*b.x-y*b.y,x*b.y+y*b.x};}
12     P operator/(db b) {return {x/b,y/b};}
13 } p[4][N];
14 void FFT(P* a,int n,int f) {
15     for(int i=1,j=n>>1,k; i<n-1; ++i,j^=k) {
16         if(i<j)swap(a[i],a[j]);
17         for(k=n>>1; j&k; j^=k,k>>=1);
18     }
19     for(int k=1; k<n; k<<=1) {
20         P wn= {cos(pi/k),f*sin(pi/k)};
21         for(int i=0; i<n; i+=k<<1) {
22             P w= {1,0};
23             for(int j=i; j<i+k; ++j,w=w*wn) {
24                 P x=a[j],y=w*a[j+k];
25                 a[j]=x+y,a[j+k]=x-y;
26             }
27         }
28     }
29     if(!~f)for(int i=0; i<n; ++i)a[i]=a[i]/n;
30 }
31 int com[N],a,b,c;
32 
33 int main() {
34     memset(com,0,sizeof com);
35     for(int i=2; i<N; ++i)if(!com[i])for(int j=i*2; j<N; j+=i)com[j]=1;
36     while(scanf("%d%d%d",&a,&b,&c)&&(a||b||c)) {
37         int m;
38         for(m=1; m<=b*2; m<<=1);
39         for(int f=0; f<4; ++f)fill(p[f],p[f]+m,(P) {0,0});
40         while(c--) {
41             int x;
42             char ch;
43             scanf("%d%c",&x,&ch);
44             if(x>b)continue;
45             if(ch=='S')p[0][x].x--;
46             else if(ch=='H')p[1][x].x--;
47             else if(ch=='C')p[2][x].x--;
48             else if(ch=='D')p[3][x].x--;
49         }
50         for(int f=0; f<4; ++f)
51             for(int i=0; i<=b; ++i)p[f][i].x+=com[i];
52         for(int i=1; i<4; ++i)FFT(p[i],m,1);
53         for(int f=1; f<4; ++f) {
54             FFT(p[0],m,1);
55             for(int i=0; i<m; ++i)p[0][i]=p[0][i]*p[f][i];
56             FFT(p[0],m,-1);
57             for(int i=b+1; i<m; ++i)p[0][i]= {0,0};
58         }
59         for(int i=a; i<=b; ++i)printf("%lld\n",ll(p[0][i].x+0.5));
60         puts("");
61     }
62     return 0;
63 }

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM