自己也看了幾篇博客,但是對我這種不擅長推導小白來說還是有一點困難,所以自己也寫一篇博客也為像我一樣的小白提供思路。以下內容包含各種LaTeX渲染,如果哪里有錯誤歡迎大家評論留言,或者添加本人qq:1403482164(無事勿擾)
一、FFT的應用場景
\(A(x) \text{=} a_0 \text{+} a_1x+a_2x^2+……+a_nx^n\)
\(B(x) \text{=} b_0 \text{+} b_1x+b_2x^2+……+b_mx^m\)
\(C(x) \text{=} A(x)\times B(x)\) 求解\(C\)的各項系數
二、前置知識
1. 復數
-
復數的表示形式:\(a+bi = r(cos\theta + isin\theta) = re^{\theta i}\),前兩種高一課程應該都學過,最后請參考歐拉公式
-
復數在平面直角坐標系上的表示:
主要參考是復數的第一,二種表示形式,實部(\(a / rcos\theta\))為橫坐標,虛部(\(b / risin\theta\))為縱坐標,也可以看成是半徑為\(r\)的一個圓上的某一個點,圓上的點就可以表示出長度為\(r\)的所有復數
2. 單位根
\(x = 1\),\(x\)有一個解為1,在平面直角坐標系中對應\((1,0)\)
\(x^2 = 1\),\(x\)有兩個解為\(1,\text{-}1\),在平面直角坐標系中對應\((1,0)\)和\((\text{-}1,0)\)
\(x^3 = 1\),\(x\)有三個解為\(1,\dfrac{\text{-}1\text{+}\sqrt{3}i}{2},\dfrac{\text{-}1\text{-}\sqrt{3}i}{2}\),在平面直角坐標系中對應\((1,0),(\text{-}\frac{1}{2},\frac{\sqrt{3}}{2}),(\text{-}\frac{1}{2},\text{-}\frac{\sqrt{3}}{2})\)
……
依次類推當\(w^k = 1\)時,\(w\)有\(k\)個解,且這些解在平面直角坐標系中表示出來后把半徑為1的圓平分成\(k\)等份,這樣的解叫做單位根,其中\(w_{n}^{k}\)表示\(w^n = 1\)的第\(k\)個解
- 單位根的性質
根據復數的表示形式\(w_{n}^{k} = cos\theta \text{+}isin\theta = re^{\theta i}\),其中\(\theta = \frac{2\pi k}{n}\)
通過單位根的形式變化和三角函數計算,可以推出他的一些性質:
-
\({(w_{n}^{k})}^2 = w_{n}^{2k} = w_{\frac{n}{2}}^{k}\)
-
\(w_{2n}^{2k} = w_{n}^{k}\)
-
\(w_{n}^{\frac{n}{2}+k} = \text{-}w_{n}^{k}\)
3. 矩陣乘法
給出\(A(x) = a_0+a_1x+……a_nx^n\)
\(A = \begin{bmatrix}{x_1}^0&{x_1}^1&\cdots&{x_1}^n\\{x_2}^0&{x_2}^1&\cdots&{x_2}^n\\\vdots&\vdots&\vdots&\vdots\\{x_n}^0&{x_n}^1&\cdots&{x_n}^n\end{bmatrix},B = \begin{bmatrix}a_0\\a_1\\\vdots\\a_n\end{bmatrix},C = \begin{bmatrix}A(x_1)\\A(x_2)\\\vdots\\A(x_n)\end{bmatrix}\)
三個矩陣的關系顯而易見:\(A\times B = C\),已知\(A,B\)的時候自然可求出\(C\),如果已知\(A,C\)如何求\(B\)呢?
我們引入了兩個新的名詞:單位矩陣和逆矩陣
-
單位矩陣\(I\):對角線為1,其余全為0,且滿足\(A\times I = A\)
-
逆矩陣:若\(A\times A^{-1} = I\)則稱\(A^{-1}\)為\(A\)的逆矩陣
在等式兩邊同時乘\(A^{-1}\)變成:\(A\times A^{-1}\times B = C\times A^{-1} \Rightarrow I\times B = C\times A^{-1} \Rightarrow B = C\times A^{-1}\)
這樣我們就可以求解\(B\)矩陣了。
4. 函數的表示方法:
-
系數法:已知所有項的系數當然可以確定一個函數
-
點值法:\(n\)項的函數,在平面直角坐標系中找\(n+1\)個點就可以確定這個函數
具體證明就不給了,畢竟oi大多數時候還是考感性理解的
三、FFT
講了個這么多,終於把前置知識處理完畢了,接下來就是正菜FFT了,不同於其他博客,可能我不會引用專業名詞像DFT等等……但是精華都是一樣的,相信大家都有超前的思考能力,發現某個問題不知道如何處理,耐心看下去,也許會發現不一樣的精彩。
1. 為什么用FFT,他如何優化問題
首先還是這個式子:\(A(x) = a_0+a_1x^2+……a_nx^n\)
當我給出一個\(x\),朴素求解的時間復雜度是多少呢? 循環疊加\(x\),\(O(n)\)求解,而我們需要\(n+1\)個點,那么確定一個函數的復雜度就是\(O(n^2)\)的
顯然不夠優秀是吧,而他的困難就在於對每個\(x\)求\(A(x)\)的過程,所以我們想借助某種方法把這個過程優化到\(O(logn)\),就發明了FFT
2. 具體流程
1. 先想辦法把式子處理一下:
奇偶分離(方便推導就假設\(n\)為偶數,實際寫代碼的時候,往后加0系數就可以了):
\(A_0(x) = a_0+a_2x^2+……a_nx^n,A_1(x) = a_1x+a_3x^3+……+a_{n+1}x^{n+1}\),其中\(a_{n+1} = 0\),為了讓兩個函數的項數相同
此時如果\(A_0,A_1,A\)三者的形式相同就好了,這樣就可以統一求解了
我們可以發現如果我從\(A_1(x)\)中提取一個\(x\)出來,那么他們之間的形式是一樣的,所以我讓\(xA_1(x) = a_1x+a_3x^3+a_{n+1}x^{n+1}\),那么\(A_1(x) = a_1 +a_3x^2+……+a_{n+1}x^n\)
那么現在\(A_0\)和\(A_1\)的形式相同,我們還需要把他們兩個和\(A\)統一形式,很簡單,我們把\(x\)替換成\(y = x^2\),此時\(A_0(y) = a_0+a_2y+a_4y^2+……+a_ny^{\frac{n}{2}}\),\(A_1(y) = a_1+a_3y+……+a_{n+1}y^{\frac{n}{2}}\)
到現在為止,我們已經把\(A_0\)和\(A_1\)從\(A\)中提取出來了,又要怎么合並起來呢?
其實這里面還有一個隱藏的關系:\(A(x) = A_0(x^2)+xA_1(x^2)\),往上面的式子里代一下,就可以得知了,這樣分成兩部分算,時間縮短了一半
2. 求解點值
巧妙的地方來了,除了式子的處理,我代入的點也十分的有講究。大家也能猜到,前置知識講的單位根肯定不是白講的吧。
把單位根\(w_{n}^{k}\)代入可得:
當\(k \leq \frac{n}{2}\)時,\(A(w_{n}^{k}) = A_0({w_{n}^{k}}^2)+w_{n}^{k}A_1({w_{n}^{k}}^2) = A_0(w_{\frac{n}{2}}^{k})+w_{n}^{k}A_1(w_{\frac{n}{2}}^{k})\)
當\(k \text{>} \frac{n}{2}\)時,把\(k+\frac{n}{2}\)代入,\(A(w_{n}^{k+\frac{n}{2}}) = A_0({w_{n}^{k+\frac{n}{2}}}^2)+w_{n}^{k+\frac{n}{2}}A_1({w_{n}^{k+\frac{n}{2}}}^2) = A_0(w_{\frac{n}{2}}^{k})-w_{n}^{k}A_1(w_{\frac{n}{2}}^{k})\)
我們發現兩個情況只有加減號不同,所以在求解前一半的時候,后一半同時可以求解,時間又縮小了一半,\(O(logn)\)的復雜度就出來了
3. 求解最終函數
算法進行到了最最最最后一步了,現在我們通過上述的一系列計算,已知了\(A(x)\)和\(B(x)\)的\(n+m+1\)個點,\(C(x) = A(x)*B(x)\),自然就得出了\(C(x)\)的\(n+m+1\)個點,這就要考察到我們前置內容中的矩陣乘法部分了。
我們已知\(A = \begin{bmatrix}{w_n^1}^0&{w_n^1}^1&\cdots&{w_n^1}^n\\{w_n^2}^0&{w_n^2}^1&\cdots&{w_n^2}^n\\\vdots&\vdots&\vdots&\vdots\\{w_n^n}^0&{w_n^n}^1&\cdots&{w_n^n}^n\end{bmatrix},C = \begin{bmatrix}C(w_n^1)\\C(w_n^2)\\\vdots\\C(w_n^n)\end{bmatrix}\),來求解\(B\)矩陣
所以問題轉化成求解\(A^{-1}\)矩陣,給出一個結論,\(A^{-1}\)矩陣就等於把\(w\)的所有上角標取負再乘\(\frac{1}{n}\)
以上就是FFT的全部流程,剩下還有動態規划優化,敬請期待下回詳解……
四、遞歸版代碼
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstring>
#include <cmath>
using namespace std;
const double pi = acos(-1.0);const int maxn = 1e7+10;
int read(){
int x = 1,a = 0;char ch = getchar();
while (ch < '0'||ch > '9'){if (ch == '-') x = -1;ch = getchar();}
while (ch >= '0'&&ch <= '9'){a = a*10+ch-'0';ch = getchar();}
return x*a;
}
struct node{
double x,y;
}a[maxn], b[maxn];
node operator + (node a,node b){return node{a.x+b.x,a.y+b.y};}
node operator - (node a,node b){return node{a.x-b.x,a.y-b.y};}
node operator * (node a,node b){return node{a.x*b.x-a.y*b.y,a.y*b.x+a.x*b.y};}
int n,m;
void fft(int len, node *a, int op){
if (len == 1) return;
node a0[(len>>1)+3],a1[(len>>1)+3];
for (int i = 0;i <= len;i += 2) a0[i>>1] = a[i],a1[i>>1] = a[i+1];
fft(len>>1,a0,op);fft(len>>1,a1,op);
node wn = node{cos(2*pi/len),op*sin(2*pi/len)},w0 = node{1,0};
for(int i = 0;i < (len >> 1);i++,w0 = w0*wn){
a[i] = a0[i]+w0*a1[i];
a[i+ (len>>1)] = a0[i]- w0*a1[i];
}
}
int main(){
//freopen("in.in","r",stdin);
//freopen("out.out","w",stdout);
n = read(),m = read();
for (int i = 0;i <= n;i++) a[i].x = read();
for (int i = 0;i <= m;i++) b[i].x = read();
int len = 1;
while (len <= n+m) len <<= 1;
fft(len,a,1);fft(len,b,1);
for (int i = 0;i <= len;i++) a[i] = a[i]*b[i];
fft(len,a,-1);
for (int i = 0;i <= n+m;i++) printf("%.0lf ",a[i].x/len);
return 0;
}
被WC摧殘的第一天,回來更新了……
首先來模擬一下遞歸的過程(盜用wiki):
規律: 其實就是原來的序列,每個數用二進制表示,然后把二進制翻轉對稱一下,就是最終那個位置的下標。比如\(x\)是001,翻轉是100,也就是4,即\(x\)最后的位置。我們稱這個變換為位逆序置換(bit-reversal permutation,國內也稱蝴蝶變換)。(把二進制寫出來就很好發現了,就是想不到寫二進制呀)
進而我們還可以發現,每一層對應的兩個數都是由他下面一層對應的兩個數更新而來,這樣問題就轉化成了求最后一層數的順序了,也就是求每個數的位逆序……
因為每個數都要求一遍,所以想能到dp轉移
狀態定義:\(dp_i\)表示\(i\)的位逆序
狀態轉移:\(dp_i = \frac{dp_{i/2}}{2} \text{+} (i\text{%}2)\times 2^{l-1}\)(自己模擬一下過程,還是挺顯而易見的)
五、非遞歸版代碼(話說luogu卡精度來着)
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstring>
#include <cmath>
using namespace std;
const double pi = acos(-1.0);const int maxn = 1e7+10;
int read(){
int x = 1,a = 0;char ch = getchar();
while (ch < '0'||ch > '9'){if (ch == '-') x = -1;ch = getchar();}
while (ch >= '0'&&ch <= '9'){a = a*10+ch-'0';ch = getchar();}
return x*a;
}
struct node{
double x,y;
}a[maxn], b[maxn],a0[maxn],a1[maxn];
node operator + (node a,node b){node x = {a.x+b.x,a.y+b.y};return x;}
node operator - (node a,node b){node x = {a.x-b.x,a.y-b.y};return x;}
node operator * (node a,node b){node x = {a.x*b.x-a.y*b.y,a.y*b.x+a.x*b.y};return x;}
int n,m,dp[maxn];
double coss[maxn],sinn[maxn];
void fft(int len,node *a,int op){
for (int i = 0;i <= len;i++){
if (i < dp[i]) swap(a[i],a[dp[i]]);
}
for (int l = 1;l < len;l<<=1){
node wn = {coss[l],op*sinn[l]};
for (int i = 0;i < len;i+=(l << 1)){
node w0 = {1,0};
for (int j = 0;j < l;j++,w0 = w0*wn){
node x = a[i+j],y = w0*a[i+j+l];
a[i+j] = x+y,a[i+j+l] = x-y;
}
}
}
}
int main(){
//freopen("in.in","r",stdin);
//freopen("out.out","w",stdout);
n = read(),m = read();
for (int i = 0;i <= n;i++) a[i].x = read();
for (int i = 0;i <= m;i++) b[i].x = read();
int len = 1,num = 0;
while (len <= n+m) len <<= 1,num++;
for (int i = 0;i <= len;i++) dp[i] = (dp[i>>1]>>1)|((i&1)<<(num-1));
for (int i = 1;i <= len;i <<= 1) coss[i] = cos(pi/i),sinn[i] = sin(pi/i);
fft(len,a,1);fft(len,b,1);
for (int i = 0;i <= len;i++) a[i] = a[i]*b[i];
fft(len,a,-1);
for (int i = 0;i <= n+m;i++) printf("%d ",(int)(a[i].x/len+0.5));
return 0;
}