小學生都能看懂的FFT!!!
前言
在創新實踐中心偷偷看了一天FFT資料后,我終於看懂了一點。為了給大家提供一份簡單易懂的學習資料,同時也方便自己以后復習,我決定動手寫這份學習筆記。
食用指南:
本篇受眾:如標題所示,另外也面向同我一樣高中起步且非常菜的OIer。真正的dalao請無視。
本篇目標:讓大家(和不知道什么時候把FFT忘了的我)在沒有數學基礎的情況下,以最快的速度了解並 會寫 FFT。因此本篇將采用盡可能通俗易懂的語言,且略過大部分數學證明,在嚴謹性上可能有欠缺。但如果您發現了較大的邏輯漏洞,歡迎在評論里指正!最后……來個版權聲明吧。本文作者胡小兔,博客地址http://rabbithu.cnblogs.com。暫未許可在任何其他平台轉載。
你一定聽說過FFT,它的高逼格名字讓人望而卻步——“快速傅里葉變換”。
你可能知道它可以\(O(n \log n)\)求高精度乘法,你想學,可是面對一堆公式,你無從下手。
那么歡迎閱讀這篇教程!
[Warning] 本文涉及復數(虛數)的一小部分內容,這可能是最難的部分,但只要看下去也不是非常難,請不要看到它就中途退出啊QAQ。
什么是FFT?
快速傅里葉變換(FFT)是一種能在\(O(n \log n)\)的時間內將一個多項式轉換成它的點值表示的算法。
補充資料:什么是點值表示
設\(A(x)\)是一個\(n - 1\)次多項式,那么把\(n\)個不同的\(x\)代入,會得到\(n\)個\(y\)。這\(n\)對\((x, y)\)唯一確定了該多項式,即只有一個多項式能同時滿足“代入這些\(x\),得到的分別是這些\(y\)”。
由多項式可以求出其點值表示,而由點值表示也可以求出多項式。(並不想證明,十分想看證明的同學請前往“參考資料”部分)。
注:下文如果不加特殊說明,默認所有\(n\)為2的整數次冪。如果一個多項式次數不是2的整數次冪,可以在后面補0。
為什么要使用FFT?
FFT可以用來加速多項式乘法(平時非常常用的高精度大整數乘法就是最終把\(x = 10\)代入的多項式乘法)。
假設有兩個\(n-1\)次多項式\(A(x)\)和\(B(x)\),我們的目標是——把它們乘起來。
普通的多項式乘法是\(O(n^2)\)的——我們要枚舉\(A(x)\)中的每一項,分別與\(B(x)\)中的每一項相乘,來得到一個新的多項式\(C(x)\)。
但有趣的是,兩個用點值表示的多項式相乘,復雜度是\(O(n)\)的!具體方法:\(C(x_i) = A(x_i) \times B(x_i)\),所以\(O(n)\)枚舉\(x_i\)即可。
要是我們把兩個多項式轉換成點值表示,再相乘,再把新的點值表示轉換成多項式豈不就可以\(O(n)\)解決多項式乘法了!
……很遺憾,顯然,把多項式轉換成點值表示的朴素算法是\(O(n^2)\)的。另外,即使你可能不會——把點值表示轉換為多項式的朴素“插值算法”也是\(O(n^2)\)的。
難道大整數乘法就只能是\(O(n^2)\)嗎?!不甘心的同學可以發現,大整數乘法復雜度的瓶頸可能在“多項式轉換成點值表示”這一步(以及其反向操作),只要完成這一步就可以\(O(n)\)求答案了。如果能優化這一步,豈不美哉?
傅里葉:這個我會!
離散傅里葉變換(快速傅里葉變換的朴素版)
傅里葉發明了一種辦法:規定點值表示中的\(n\)個\(x\)為\(n\)個模長為\(1\)的復數。
——等等,先別看到復數就走!
補充資料:什么是復數
如果你學過復數,這段不用看了;
如果你學過向量,請把復數理解成一個向量;
如果你啥都沒學過,請把復數理解成一個平面直角坐標系上的點。復數具有一個實部和一個虛部,正如一個向量(或點)有一個橫坐標和一個縱坐標。例如復數\(3 + 2i\),實部是\(3\),虛部是\(2\),\(i = \sqrt{-1}\)。可以把它想象成向量\((3, 2)\)或點\((3, 2)\)。
但復數比一個向量或點更妙的地方在於——復數也是一種數,它可以像我們熟悉的實數那樣進行加減乘除等運算,還可以代入多項式\(A(x)\)——顯然你不能把一個向量或點作為\(x\)代入進去。
復數相乘的規則:模長相乘,幅角相加。模長就是這個向量的模長(或是這個點到原點的距離);幅角就是x軸正方向逆時針旋轉到與這個向量共線所途徑的角(或是原點出發、指向x軸正方向的射線逆時針旋轉至過這個點所經過的角)。想學會FFT,“模長相乘”暫時不需要了解過多,但“幅角相加”需要記住。
C++的STL提供了復數模板!
頭文件:#include <complex>
定義: complex<double> x;
運算:直接使用加減乘除。
傅里葉要用到的\(n\)個復數,不是隨機找的,而是——把單位圓(圓心為原點、1為半徑的圓)\(n\)等分,取這\(n\)個點(或點表示的向量)所表示的虛數,即分別以這\(n\)個點的橫坐標為實部、縱坐標為虛部,所構成的虛數。
從點\((1, 0)\)開始(顯然這個點是我們要取的點之一),逆時針將這\(n\)個點從\(0\)開始編號,第\(k\)個點對應的虛數記作\(\omega_n^k\)(根據復數相乘時模長相乘幅角相加可以看出,\(\omega_n^k\)是\(\omega_n^1\)的\(k\)次方,所以\(\omega_n^1\)被稱為\(n\)次單位根)。
根據每個復數的幅角,可以計算出所對應的點/向量。\(\omega_n^k\)對應的點/向量是\((\cos \frac{k}{n}2\pi, \sin \frac{k}{n}2\pi)\),也就是說這個復數是\(\cos \frac{k}{n}2\pi + i\sin \frac{k}{n}2\pi\)。
傅里葉說:把\(n\)個復數\(\omega_n^0, \omega_n^1, \omega_n^2, ..., \omega_n^{n-1}\)代入多項式,能得到一種特殊的點值表示,這種點值表示就叫離散傅里葉變換吧!
[Warning] 從現在開始,本文個別部分會集中出現數學公式,但是都不是很難,公式恐懼症患者請堅持!Stay Determined!
補充資料:單位根的性質
性質一:\(\omega_{2n}^{2k} = \omega_{n}^{k}\)
證明:它們對應的點/向量是相同的。
性質二:\(\omega_{n}^{k + \frac{n}{2}} = -\omega_{n}^{k}\)
證明:它們對應的點是關於原點對稱的(對應的向量是等大反向的)。
為什么要使用單位根作為\(x\)代入
當然是因為離散傅里葉變換有着特殊的性質啦。
[Warning] 下面有一些證明,如果不想看,請跳到加粗的“一個結論”部分。
設\((y_0, y_1, y_2, ..., y_{n - 1})\)為多項式\(A(x) = a_0 + a_1x + a_2x^2 +...+a_{n-1}x^{n-1}\)的離散傅里葉變換。
現在我們再設一個多項式\(B(x) = y_0 + y_1x + y_2x^2 +...+y_{n-1}x^{n-1}\),現在我們把上面的\(n\)個單位根的倒數,即\(\omega_{n}^{0}, \omega_{n}^{-1}, \omega_{n}^{-2}, ..., \omega_{n}^{-(n - 1)}\)作為\(x\)代入\(B(x)\), 得到一個新的離散傅里葉變換\((z_0, z_1, z_2, ..., z_{n - 1}\))。
這個\(\sum_{i = 0}^{n - 1}(\omega_n^{j - k})^i\)是可求的:當\(j - k = 0\)時,它等於\(n\); 其余時候,通過等比數列求和可知它等於\(\frac{(\omega_n^{j - k})^n - 1}{\omega_n^{j - k} - 1} = \frac{(\omega_n^n)^{j - k} - 1}{\omega_n^{j - k} - 1} = \frac{1^{j - k}- 1}{\omega_n^{j - k} - 1} = 0\)。
那么,\(z_k\)就等於\(na_k\), 即:
一個結論
把多項式\(A(x)\)的離散傅里葉變換結果作為另一個多項式\(B(x)\)的系數,取單位根的倒數即\(\omega_{n}^{0}, \omega_{n}^{-1}, \omega_{n}^{-2}, ..., \omega_{n}^{-(n - 1)}\)作為\(x\)代入\(B(x)\),得到的每個數再除以n,得到的是\(A(x)\)的各項系數。這實現了傅里葉變換的逆變換——把點值表示轉換成多項式系數表示,這就是離散傅里葉變換神奇的特殊性質。
快速傅里葉變換
雖然傅里葉發明了神奇的變換,能把多項式轉換成點值表示又轉換回來,但是……它仍然是暴力代入的做法,復雜度仍然是\(O(n^2)\)啊!(傅里葉:我都沒見過計算機,我干啥要優化復雜度……)
於是,快速傅里葉變換應運而生。它是一種分治的傅里葉變換。
[Warning] 下面有較多公式。看起來很嚇人,但是並不復雜。請堅持看完。
快速傅里葉變換的數學證明
仍然,我們設\(A(x) = a_0 + a_1x + a_2x^2 +...+a_{n-1}x^{n-1}\),現在為了求離散傅里葉變換,要把一個\(x = \omega_n^k\)代入。
考慮將\(A(x)\)的每一項按照下標的奇偶分成兩部分:
設兩個多項式:
則:
假設\(k < \frac{n}{2}\),現在要把\(x = \omega_n^k\)代入:
那么對於\(A(\omega_n^{k + \frac{n}{2}})\):
所以,如果我們知道兩個多項式\(A_1(x)\)和\(A_2(x)\)分別在\((\omega_{\frac{n}{2}}^{0}, \omega_{\frac{n}{2}}^{1}, \omega_{\frac{n}{2}}^{2}, ... , \omega_{\frac{n}{2}}^{\frac{n}{2} - 1}\))的點值表示,就可以\(O(n)\)求出\(A(x)\)在\(\omega_n^0, \omega_n^1, \omega_n^2, ..., \omega_n^{n-1}\)處的點值表示了。而\(A_1(x)\)和\(A_2(x)\)都是規模縮小了一半的子問題。分治邊界是\(n = 1\),此時直接return。
快速傅里葉變換的實現
寫個遞歸就可以實現一個FFT了!
cp omega(int n, int k){
return cp(cos(2 * PI * k / n), sin(2 * PI * k / n));
}
void fft(cp *a, int n, bool inv){
if(n == 1) return;
static cp buf[N];
int m = n / 2;
for(int i = 0; i < m; i++){ //將每一項按照奇偶分為兩組
buf[i] = a[2 * i];
buf[i + m] = a[2 * i + 1];
}
for(int i = 0; i < n; i++)
a[i] = buf[i];
fft(a, m, inv); //遞歸處理兩個子問題
fft(a + m, m, inv);
for(int i = 0; i < m; i++){ //枚舉x,計算A(x)
cp x = omega(n, i);
if(inv) x = conj(x);
//conj是一個自帶的求共軛復數的函數,精度較高。當復數模為1時,共軛復數等於倒數
buf[i] = a[i] + x * a[i + m]; //根據之前推出的結論計算
buf[i + m] = a[i] - x * a[i + m];
}
for(int i = 0; i < n; i++)
a[i] = buf[i];
}
inv表示這次用的單位根是否要取倒數。
至此你已經會寫fft了!但是這個fft還是1.0版本,比較慢(可能同時還比較長?),親測可能會比加了一些優化的fft慢了4倍左右……
那么我們來學習一些優化吧!
優化fft
非遞歸fft
在進行fft時,我們要把各個系數不斷分組並放到兩側,那么一個系數原來的位置和最終的位置有什么規律呢?
初始位置:0 1 2 3 4 5 6 7
第一輪后:0 2 4 6|1 3 5 7
第二輪后:0 4|2 6|1 5|3 7
第三輪后:0|4|2|6|1|5|3|7
“|”代表分組界限。
可以發現(這你都能發現?),一個位置a上的數,最后所在的位置是“a二進制翻轉得到的數”,例如6(011)最后到了3(110),1(001)最后到了4(100)。
那么我們可以據此寫出非遞歸版本fft:先把每個數放到最后的位置上,然后不斷向上還原,同時求出點值表示。
代碼:
cp a[N], b[N], omg[N], inv[N];
void init(){
for(int i = 0; i < n; i++){
omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
inv[i] = conj(omg[i]);
}
}
void fft(cp *a, cp *omg){
int lim = 0;
while((1 << lim) < n) lim++;
for(int i = 0; i < n; i++){
int t = 0;
for(int j = 0; j < lim; j++)
if((i >> j) & 1) t |= (1 << (lim - j - 1));
if(i < t) swap(a[i], a[t]); // i < t 的限制使得每對點只被交換一次(否則交換兩次相當於沒交換)
}
static cp buf[N];
for(int l = 2; l <= n; l *= 2){
int m = l / 2;
for(int j = 0; j < n; j += l)
for(int i = 0; i < m; i++){
buf[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m];
buf[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m];
}
for(int j = 0; j < n; j++)
a[j] = buf[j];
}
}
可以預處理\(\omega_n^k\)和\(\omega_n^{-k}\),分別存在omg和inv數組中。調用fft時,如果無需取倒數,則傳入omg;如果需要取倒數,則傳入inv。
蝴蝶操作
這個優化有着一個高大上的名字——“蝴蝶操作”。我第一次看到這個名字時就嚇跑了——尤其是看到那種帶示意圖的蝴蝶操作解說時。
但是你完全無需跑!這是一個很簡單的優化,它可以丟掉上面代碼里的那個buf數組。
我們為什么需要buf數組?因為我們要做這兩件事:
a[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m]
a[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m]
但是我們又要求這兩行不能互相影響,所以我們需要buf數組。
但是如果我們這樣寫:
cp t = omg[n / l * i] * a[j + i + m]
a[j + i + m] = a[j + i] - t
a[j + i] = a[j + i] + t
就可以原地進行了,不需要buf數組。
cp a[N], b[N], omg[N], inv[N];
void init(){
for(int i = 0; i < n; i++){
omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
inv[i] = conj(omg[i]);
}
}
void fft(cp *a, cp *omg){
int lim = 0;
while((1 << lim) < n) lim++;
for(int i = 0; i < n; i++){
int t = 0;
for(int j = 0; j < lim; j++)
if((i >> j) & 1) t |= (1 << (lim - j - 1));
if(i < t) swap(a[i], a[t]); // i < t 的限制使得每對點只被交換一次(否則交換兩次相當於沒交換)
}
for(int l = 2; l <= n; l *= 2){
int m = l / 2;
for(cp *p = a; p != a + n; p += l)
for(int i = 0; i < m; i++){
cp t = omg[n / l * i] * p[i + m];
p[i + m] = p[i] - t;
p[i] += t;
}
}
}
現在,這個fft就比之前的遞歸版快很多了!
到此為止我的FFT筆記就整理完啦。
下面貼一個FFT加速高精度乘法的代碼:
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <complex>
#define space putchar(' ')
#define enter putchar('\n')
using namespace std;
typedef long long ll;
template <class T>
void read(T &x){
char c;
bool op = 0;
while(c = getchar(), c < '0' || c > '9')
if(c == '-') op = 1;
x = c - '0';
while(c = getchar(), c >= '0' && c <= '9')
x = x * 10 + c - '0';
if(op) x = -x;
}
template <class T>
void write(T x){
if(x < 0) putchar('-'), x = -x;
if(x >= 10) write(x / 10);
putchar('0' + x % 10);
}
const int N = 1000005;
const double PI = acos(-1);
typedef complex <double> cp;
char sa[N], sb[N];
int n = 1, lena, lenb, res[N];
cp a[N], b[N], omg[N], inv[N];
void init(){
for(int i = 0; i < n; i++){
omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
inv[i] = conj(omg[i]);
}
}
void fft(cp *a, cp *omg){
int lim = 0;
while((1 << lim) < n) lim++;
for(int i = 0; i < n; i++){
int t = 0;
for(int j = 0; j < lim; j++)
if((i >> j) & 1) t |= (1 << (lim - j - 1));
if(i < t) swap(a[i], a[t]); // i < t 的限制使得每對點只被交換一次(否則交換兩次相當於沒交換)
}
for(int l = 2; l <= n; l *= 2){
int m = l / 2;
for(cp *p = a; p != a + n; p += l)
for(int i = 0; i < m; i++){
cp t = omg[n / l * i] * p[i + m];
p[i + m] = p[i] - t;
p[i] += t;
}
}
}
int main(){
scanf("%s%s", sa, sb);
lena = strlen(sa), lenb = strlen(sb);
while(n < lena + lenb) n *= 2;
for(int i = 0; i < lena; i++)
a[i].real(sa[lena - 1 - i] - '0');
for(int i = 0; i < lenb; i++)
b[i].real(sb[lenb - 1 - i] - '0');
init();
fft(a, omg);
fft(b, omg);
for(int i = 0; i < n; i++)
a[i] *= b[i];
fft(a, inv);
for(int i = 0; i < n; i++){
res[i] += floor(a[i].real() / n + 0.5);
res[i + 1] += res[i] / 10;
res[i] %= 10;
}
for(int i = res[lena + lenb - 1] ? lena + lenb - 1: lena + lenb - 2; i >= 0; i--)
putchar('0' + res[i]);
enter;
return 0;
}