FFT 的優化和任意模數 FFT
1. 前言和前置技能
這篇主要講卡常如何卡到,以及任意模數的 FFT 。uoj
榜第二頁
雖然當時最開始學的是 \(FFT\),然而似乎沒什么人用,因為多項式全家桶里用的都是 \(NTT\)。而且 \(FFT\) 跑的也並不是很快,所以大概沒什么人學吧。
然而我們會發現uoj
榜前面的都是 \(FFT\),因為毛爺爺(毛嘯)的論文里介紹了很多卡(qi)常(ji)技(yin)巧(qiao)。實際效果非常顯著,速度是普通 \(FFT\) 的 \(2\) 到 \(5\) 倍不等。
\(FFT\) 有個巨大的先天優勢:運算都是在復數域上的。雖然會出現各種精度誤差以及常數巨大,但是 \(NTT\) 要是想借此來嘲笑 \(FFT\) 也不過是五十步笑百步罷了。
\(NTT\) 最大的優勢在於運算都是在模意義下進行的,這就可以實現各種模意義下的操作。然而 \(NTT\) 最大的劣勢也在於模意義,於是當模數不是 \(NTT\) 模數的時候就會當場去世。說到底,模意義下的運算總是束手束腳的,為什么我們不走向連續的更加廣闊的復數域呢?
不知道大家有沒有發現一個細節,\(FFT\) 在 \(DFT\) 之前和 \(IDFT\) 之后各系數的虛部都是 \(0\) 。這就造成了巨大的浪費。我們的 \(FFT\) 本身就能對復數域進行 \(DFT\) ,為什么我們不能利用一下這個神奇的性質呢?
你需要的技能有:
- \(FFT\)
- 一些關於復數的小知識
回顧一下 \(DFT\) 干了什么:
對於一個多項式 \(f(x) = \sum\limits_{j=0}^na_jx^j\) ,設 \(N\) 表示延長至 \(2\) 的整次冪的長度 ,\(X\) 表示 \(DFT\) 之后的結果, \(w_N\) 表示 \(N\) 次單位復根,則有:
$\bar z $ 表示 $ z $ 的共軛。復數 \(a + bi\) 的共軛是 \(a - bi\) 。\(\overline {z_iz_j} = \bar z_i \bar z_j\)
然后以下的邊界可能會出一些問題,我們默認 \(N - 0\) 和 \(0\) 是同一個位置,但是大於等於 \(N\) 的位置系數都是 \(0\)。
2. 合並
\(DFT\) 其實是個挺對稱的過程,因為 \(w_N^k = -w_N^{\frac N 2+k}\) 。我們來比較一下 \(X_i\) 和 \(X_{N-i}\) :
好像。。。。十分有理有據令人信服。我們求出 \(X_i\) 就能知道 \(X_{N-i}\) 了。那我們用 \(N\) 的長度做難道不是很浪費嗎?
為什么 \(X_j = \overline{ X_{N-j}}\) 呢?根本原因是我們的多項式系數沒有虛部。我們換一個有虛部的試一下。設
然后 \(DFT\) 一下, 得到
如何是好?
然鵝我們對比一下 \(\overline {X_{N-j}}\) 和 \(\overline{X_j}\) ,我們會發現
然后我們再用 \(X_j\) 加減 \(\overline{X_{N-j}}\),會得到:
而且對於兩個DFT \(X_j, Y_j\) ,我們可以通過同樣的方法 \(a_j = \frac 1 N \sum\limits_{k=0}^{N-1}(X_k + iY_k)w_N^{-jk}\)來得到 \(X_k, Y_k\) 的 IDFT 。
3. 分裂
一般來說關於 \(FFT\) 的常數指的是 \(DFT/IDFT\) 的次數。
我們先用上面的技巧優化多項式乘法:
現在我們要將兩個多項式 \(f(x) = \sum\limits_{j=0}^na_jx^j\) 和 \(g(x) = \sum\limits_{j=0}^mb_jx^j\) 卷起來得到 \(h(x) = \sum\limits_{j=0}^mc_jx^j\) 。正常的做法,我們需要 \(3\) 次兩倍長度的 \(FFT\)。
我們先把它分奇偶拆成兩個多項式,然后按照剛剛的合並方法,得到
記 \(X, Y, Z\) 為 \(f, g, h\) 的 \(DFT\) ,\(N\) 為單倍長度的 \(FFT\) 長度。
我們發現 \(Z_j\) 和 \(X_jY_j\) 只有一項不同,我們減一下:
(注意 \((18)\) 式的前半部分 \(k\) 不能取到 \(0\) ,所以我們換一下求和指標即可)。
然后又因為
繼續化簡得到
於是乎
這樣做時空常數都是普通寫法的一半。
4. 任意模數FFT
黑科技來了。
在任意模數下,系數有可能是 \(1e9\) 級別的,而double
的精度在 \(1e14\) 級別,於是直接 \(FFT\) 我們就炸精了。雖然我們有三模數 \(NTT\),但是那玩意,嗯,我們今天不講。
本來想說那玩意慢死了,然后覺得 FFT 笑 NTT 也不過是五十步笑百步
但是需要大量卷積的題這個玩意就是快啊
我們將某個系數 \(a_j\) 表示成 \({a_{0j}}M+a_{1j}\),然后拆成兩個系數小於 \(M\) 的多項式 \(A_0(x)\) 和 \(A_1(x)\) ,於是我們就能直接 FFT 了。一般 \(M\) 取 \(32768\) 比較方便。
這樣做需要 \(8\) 次 \(FFT\)。比起 \(NTT\) 的 \(9\) 遍並沒有快到哪兒去。(話說 \(NTT\) 也能這么拆吧?)
但是上面不是講了優化技巧了嗎?我們的虛部還是空的。
然后我們就能大力拆式子啦!
我們最后要求的式子是 \(A_1B_1 + M(A_1B_0+A_0B_1)+M^2A_0B_0\)。
我們令多項式 \(F(x), G(x)\) 的各項系數為
\(X, Y\) 表示 \(F(x), G(x)\) 的 \(DFT\),則
我們此時令 \(\mathbf {A, B, C, D}\) 來表示 \(A_0, A_1, B_0, B_1\) 的 \(DFT\),於是我們有
除 \(i\) 怎么辦呢?復數除法超級麻煩的。但是我們注意到
於是變成乘法就好了。
然后接下來,我們再用類似的技巧, 構造兩個新的 \(DFT\) ,\(\mathbf {P,Q}\)。
然后 \(IDFT\) 回去,最后的答案就是 \(\Re (\mathbf P_j) + M(\Im(\mathbf P_j)+\Re(\mathbf Q_j))+M^2\Im(\mathbf Q_j)\)。
然而實際應用的時候我們還能再化簡一步。
同理
於是我們的常數更小了。
盡管如此,它還是跑不過普通的 \(NTT\),目前大概是我的 \(NTT\) 的時間的 \(1.5\) 倍吧,因為做了 \(4\) 次兩倍長度的 \(FFT\) 。
總體來說常數因子的影響還是很明顯的,最上面那個常數是正常寫法的二分之一,時間也大概在二分之一左右,而下面的寫法常數是三分之四,時間也是三分之四,總之 \(FFT\) 次數對代碼的運行時間有着重大的影響。
注意這個算法對精度要求很高,很容易炸精度。
下面的代碼里有一種處理單位根的方式,能保證每個數最多被乘 \(\log n\) 次,保證精度的時候常數也小。
#include <cstdio>
#include <cmath>
#include <cstring>
#include <cctype>
#include <algorithm>
const int LEN = 1 << 20 | 1;
char bufin[LEN];
char bufout[LEN];
char *Rd = bufin, *Wt = bufout;
#define getchar() (*Rd++)
#define putchar(x) (*Wt++ = x)
inline int read() {
register int ret, cc;
while (!isdigit(cc = getchar())){}ret = cc-48;
while ( isdigit(cc = getchar())) ret = cc-48+ret*10;
return ret;
}
inline void write(int x, char ch = '\n') {
register int stk[20], tp;
stk[tp = !x] = 0;
while (x) stk[++tp] = x % 10, x /= 10;
while (tp) putchar(stk[tp--] + '0');
putchar(ch);
}
struct Complex {
double x, y;
Complex(double x = 0, double y = 0)
:x(x), y(y) { }
inline Complex operator + (const Complex& rhs) const { return Complex(x + rhs.x, y + rhs.y); }
inline Complex operator - (const Complex& rhs) const { return Complex(x - rhs.x, y - rhs.y); }
inline Complex operator * (const Complex& rhs) const { return Complex(x * rhs.x - y * rhs.y, x * rhs.y + y * rhs.x); }
inline Complex operator - () const { return Complex(-x, -y); }
inline Complex operator ! () const { return Complex( x, -y); }
void print() {
printf("(%f, %f)\n", x, y);
}
};
const int MAXN = 131080;
const double PI = acos(-1.0);
inline void FFT(Complex*, int, int);
inline int getrev(int);
int rev[MAXN];
int N, M;
Complex A[MAXN];
Complex B[MAXN];
Complex C[MAXN];
Complex W[MAXN];
int tmp[MAXN];
int main() {
#ifndef ONLINE_JUDGE
freopen("test.in", "r", stdin);
#endif
fread(bufin, 1, LEN, stdin);
N = read() + 1, M = read() + 1;
for (int i = 0; i < N; ++i) (i & 1 ? A[i >> 1].y : A[i >> 1].x) = read();
for (int i = 0; i < M; ++i) (i & 1 ? B[i >> 1].y : B[i >> 1].x) = read();
int len = N + M - 1, bln = getrev((len+1)>>1);
FFT(A, bln, 1), FFT(B, bln, 1);
Complex w(1, 0);
for (int i = 0, j; i < bln; ++i, w = w * W[1]) {
j = (bln-i)&(bln-1);
C[i] = A[i]*B[i]-((A[i]-!A[j])*(B[i]-!B[j])*(w+1))*0.25;
}
FFT(C, bln, -1);
for (int i = 0; i < len; ++i) write(((i & 1) ? C[i >> 1].y : C[i >> 1].x) + 0.5, ' ');
putchar('\n');
fwrite(bufout, 1, Wt - bufout, stdout);
}
inline int getrev(int n) {
int bln = 1, bct = 0;
while (bln < n) bln <<= 1, bct++;
for (int i = 0; i < bln; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bct - 1));
return bln;
}
inline void FFT(Complex* a, int n, int opt) {
for (int i = 0; i < n; ++i)
if (i < rev[i]) std::swap(a[i], a[rev[i]]);
W[0] = 1;
for (int i = 1; i < n; i <<= 1) {
Complex wn(cos(PI / i), opt * sin(PI / i));
for (int k = i - 2; k >= 0; k -= 2)
W[k + 1] = (W[k] = W[k >> 1]) * wn;
for (int j = 0, p = (i << 1); j < n; j += p) {
Complex *x = a + j, *y = a + j + i, *w = W;
for (int k = 0; k < i; ++k, ++x, ++y, ++w) {
Complex tmp = *w * *y;
*y = *x - tmp, *x = *x + tmp;
}
}
}
if (opt == -1) for (int i = 0; i < n; ++i) a[i].x /= n, a[i].y /= n;
}
外加一個全家桶。
#include <bits/stdc++.h>
inline int read() {
int ret, cc, sign = 1;
while (!isdigit(cc = getchar()))
sign = cc == '-' ? -1 : sign;
ret = cc - 48;
while (isdigit(cc = getchar()))
ret = cc - 48 + ret * 10;
return ret * sign;
}
const int MOD = 998244353;
const int G = 3;
const int MAXN = 600010;
typedef std::vector<int> Poly;
typedef long long i64;
inline int add(int a, int b) { return (a += b) >= MOD ? a - MOD : a; }
inline int sub(int a, int b) { return (a -= b) < 0 ? a + MOD : a; }
inline int mul(int a, int b) { return 1ll * a * b % MOD; }
inline int qpow(int a, int p) {
int ret = 1;
for (p += (p < 0) * (MOD - 1); p; p >>= 1, a = mul(a, a))
if (p & 1) ret = mul(ret, a);
return ret;
}
inline Poly operator + (const Poly&, const Poly&);
inline Poly operator - (const Poly&, const Poly&);
inline Poly operator * (const Poly&, const Poly&);
inline Poly Inverse(const Poly&);
inline Poly Integrate(const Poly&);
inline Poly Derivate(const Poly&);
inline Poly Ln(const Poly&);
inline Poly Exp(const Poly&);
inline Poly Pow(const Poly&, int);
inline Poly Sqrt(const Poly&);
inline void Read(Poly&);
inline void Print(const Poly&);
int main() {
#ifdef ARK
freopen("test.in", "r", stdin);
#endif
Poly A(read()), One(1, 1);
int K = read();
Read(A);
Poly ans = Derivate(Pow(One + Ln(One +
Inverse(Exp(Integrate(Inverse(Sqrt(A)))))), K));
ans[ans.size() - 1] = 0;
Print(ans);
}
inline void Read(Poly& A) {
for (auto &x : A)
x = read();
}
inline void Print(const Poly& A) {
for (auto x : A)
printf("%d ", x);
puts("");
}
int rev[MAXN];
int W[MAXN];
inline int getrev(int n) {
int len = 1, cnt = 0;
while (len < n) len <<= 1, ++cnt;
for (int i = 0; i < len; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (cnt - 1));
return len;
}
struct Complex {
double x, y;
Complex(double x = 0, double y = 0) : x(x), y(y) {}
inline Complex operator+(const Complex& rhs) const { return Complex(x + rhs.x, y + rhs.y); }
inline Complex operator-(const Complex& rhs) const { return Complex(x - rhs.x, y - rhs.y); }
inline Complex operator*(const Complex& rhs) const {
return Complex(x * rhs.x - y * rhs.y, x * rhs.y + y * rhs.x);
}
};
Complex CW[MAXN];
const double Pi = acos(-1.0);
void FFT(Complex* a, int n, int opt) {
for (int i = 0; i < n; ++i)
if (i < rev[i])
std::swap(a[i], a[rev[i]]);
CW[0] = 1;
for (int i = 1; i < n; i <<= 1) {
Complex wn(cos(Pi / i), opt * sin(Pi / i));
for (int k = i - 2; k >= 0; k -= 2) CW[k + 1] = (CW[k] = CW[k >> 1]) * wn;
for (int j = 0, p = i << 1; j < n; j += p) {
for (int k = 0; k < i; ++k) {
Complex x = a[j + k], y = CW[k] * a[j + k + i];
a[j + k] = x + y, a[j + k + i] = x - y;
}
}
}
if (opt == -1)
for (int i = 0; i < n; ++i) a[i].x /= n, a[i].y /= n;
}
inline Poly Conv(const Poly& A, const Poly& B) {
static Complex a[MAXN], b[MAXN];
static Complex c[MAXN], d[MAXN];
int len = A.size() + B.size() - 1, bln = getrev(len);
for (size_t i = 0; i < A.size(); ++i) a[i] = Complex(A[i] & 32767, A[i] >> 15);
for (int i = A.size(); i < bln; ++i) a[i] = 0;
for (size_t i = 0; i < B.size(); ++i) b[i] = Complex(B[i] & 32767, B[i] >> 15);
for (int i = B.size(); i < bln; ++i) b[i] = 0;
FFT(a, bln, 1), FFT(b, bln, 1);
for (int i = 0, j; i < bln; ++i) {
j = (bln - i) & (bln - 1);
c[i] = Complex(a[i].x + a[j].x, a[i].y - a[j].y) * 0.5 * b[i];
d[i] = Complex(a[i].y + a[j].y, a[j].x - a[i].x) * 0.5 * b[i];
}
FFT(c, bln, -1), FFT(d, bln, -1);
Poly C(len);
for (int i = 0; i < len; ++i) {
i64 u = ((i64)(c[i].x + 0.5)) % MOD, v = ((i64)(c[i].y + 0.5)) % MOD;
i64 x = ((i64)(d[i].x + 0.5)) % MOD, y = ((i64)(d[i].y + 0.5)) % MOD;
C[i] = (u + ((v + x) << 15) % MOD + (y << 30) % MOD) % MOD;
}
return C;
}
inline void NTT(Poly& a, int n, int opt) {
a.resize(n);
for (int i = 0; i < n; ++i)
if (i < rev[i])
std::swap(a[i], a[rev[i]]);
for (int i = (W[0] = 1); i < n; i <<= 1) {
int wn = qpow(G, opt * (MOD - 1) / (i << 1));
for (int k = i - 2; k >= 0; k -= 2)
W[k + 1] = mul(W[k] = W[k >> 1], wn);
for (int j = 0, p = i << 1; j < n; j += p) {
for (int k = 0; k < i; ++k) {
int x = a[j + k], y = mul(W[k], a[j + k + i]);
a[j + k] = add(x, y), a[j + k + i] = sub(x, y);
}
}
}
if (opt == -1)
for (int i = 0, r = qpow(n, MOD - 2); i < n; ++i)
a[i] = mul(a[i], r);
}
inline Poly operator + (const Poly& lhs, const Poly& rhs) {
Poly ret(std::max(lhs.size(), rhs.size()));
for (size_t i = 0; i < ret.size(); ++i)
ret[i] = add(i < lhs.size() ? lhs[i] : 0, i < rhs.size() ? rhs[i] : 0);
return ret;
}
inline Poly operator - (const Poly& lhs, const Poly& rhs) {
Poly ret(std::max(lhs.size(), rhs.size()));
for (size_t i = 0; i < ret.size(); ++i)
ret[i] = sub(i < lhs.size() ? lhs[i] : 0, i < rhs.size() ? rhs[i] : 0);
return ret;
}
//inline Poly operator * (const Poly& lhs, const Poly& rhs) {
// Poly A(lhs), B(rhs);
// int len = A.size() + B.size() - 1;
// int bln = getrev(len);
// NTT(A, bln, 1), NTT(B, bln, 1);
// for (int i = 0; i < bln; ++i)
// A[i] = mul(A[i], B[i]);
// NTT(A, bln, -1), A.resize(len);
// return A;
//}
inline Poly operator * (const Poly& lhs, const Poly& rhs) {
return Conv(lhs, rhs);
}
inline Poly Inverse(const Poly& A) {
Poly B(1, qpow(A[0], MOD - 2));
int n = A.size() << 1;
for (int i = 2; i < n; i <<= 1) {
Poly C(A);
C.resize(i);
//int len = getrev(i << 1);
//NTT(C, len, 1), NTT(B, len, 1);
C = Poly(1, 2) - B * C;
C.resize(i);
B = B * C;
B.resize(i);
//for (int j = 0; j < len; ++j)
// B[j] = mul(B[j], sub(2, mul(B[j], C[j])));
//NTT(B, len, -1), B.resize(i);
}
B.resize(A.size());
return B;
}
inline std::vector<int> getinv() {
std::vector<int> ret(MAXN);
ret[1] = 1;
for (int i = 2; i < MAXN; ++i)
ret[i] = mul(MOD - MOD / i, ret[MOD % i]);
return ret;
}
std::vector<int> Inv = getinv();
inline Poly Integrate(const Poly& A) {
Poly C(A.size() + 1);
for (size_t i = 1; i < C.size(); ++i)
C[i] = mul(Inv[i], A[i - 1]);
return C;
}
inline Poly Derivate(const Poly& A) {
Poly C(A.size() - 1);
for (size_t i = 0; i < C.size(); ++i)
C[i] = mul(i + 1, A[i + 1]);
return C;
}
inline Poly Ln(const Poly& A) {
Poly C = Integrate(Derivate(A) * Inverse(A));
C.resize(A.size());
return C;
}
inline Poly Exp(const Poly& A) {
Poly B(1, 1);
int n = A.size() << 1;
for (int i = 2; i < n; i <<= 1) {
Poly C(A);
C.resize(i), B.resize(i);
B = B * (Poly(1, 1) - Ln(B) + C);
}
B.resize(A.size());
return B;
}
inline Poly Pow(const Poly& A, int k) {
Poly C(Ln(A));
for (auto &x : C)
x = mul(x, k);
return Exp(C);
}
inline Poly Sqrt(const Poly& A) {
Poly C(A);
int c = A[0], ic = qpow(c, MOD - 2);
for (auto &x : C)
x = mul(x, ic);
c = sqrt(c), C = Pow(C, Inv[2]);
for (auto &x : C)
x = mul(x, c);
return C;
}