FFT 的优化和任意模数 FFT
1. 前言和前置技能
这篇主要讲卡常如何卡到,以及任意模数的 FFT 。uoj
榜第二页
虽然当时最开始学的是 \(FFT\),然而似乎没什么人用,因为多项式全家桶里用的都是 \(NTT\)。而且 \(FFT\) 跑的也并不是很快,所以大概没什么人学吧。
然而我们会发现uoj
榜前面的都是 \(FFT\),因为毛爷爷(毛啸)的论文里介绍了很多卡(qi)常(ji)技(yin)巧(qiao)。实际效果非常显著,速度是普通 \(FFT\) 的 \(2\) 到 \(5\) 倍不等。
\(FFT\) 有个巨大的先天优势:运算都是在复数域上的。虽然会出现各种精度误差以及常数巨大,但是 \(NTT\) 要是想借此来嘲笑 \(FFT\) 也不过是五十步笑百步罢了。
\(NTT\) 最大的优势在于运算都是在模意义下进行的,这就可以实现各种模意义下的操作。然而 \(NTT\) 最大的劣势也在于模意义,于是当模数不是 \(NTT\) 模数的时候就会当场去世。说到底,模意义下的运算总是束手束脚的,为什么我们不走向连续的更加广阔的复数域呢?
不知道大家有没有发现一个细节,\(FFT\) 在 \(DFT\) 之前和 \(IDFT\) 之后各系数的虚部都是 \(0\) 。这就造成了巨大的浪费。我们的 \(FFT\) 本身就能对复数域进行 \(DFT\) ,为什么我们不能利用一下这个神奇的性质呢?
你需要的技能有:
- \(FFT\)
- 一些关于复数的小知识
回顾一下 \(DFT\) 干了什么:
对于一个多项式 \(f(x) = \sum\limits_{j=0}^na_jx^j\) ,设 \(N\) 表示延长至 \(2\) 的整次幂的长度 ,\(X\) 表示 \(DFT\) 之后的结果, \(w_N\) 表示 \(N\) 次单位复根,则有:
$\bar z $ 表示 $ z $ 的共轭。复数 \(a + bi\) 的共轭是 \(a - bi\) 。\(\overline {z_iz_j} = \bar z_i \bar z_j\)
然后以下的边界可能会出一些问题,我们默认 \(N - 0\) 和 \(0\) 是同一个位置,但是大于等于 \(N\) 的位置系数都是 \(0\)。
2. 合并
\(DFT\) 其实是个挺对称的过程,因为 \(w_N^k = -w_N^{\frac N 2+k}\) 。我们来比较一下 \(X_i\) 和 \(X_{N-i}\) :
好像。。。。十分有理有据令人信服。我们求出 \(X_i\) 就能知道 \(X_{N-i}\) 了。那我们用 \(N\) 的长度做难道不是很浪费吗?
为什么 \(X_j = \overline{ X_{N-j}}\) 呢?根本原因是我们的多项式系数没有虚部。我们换一个有虚部的试一下。设
然后 \(DFT\) 一下, 得到
如何是好?
然鹅我们对比一下 \(\overline {X_{N-j}}\) 和 \(\overline{X_j}\) ,我们会发现
然后我们再用 \(X_j\) 加减 \(\overline{X_{N-j}}\),会得到:
而且对于两个DFT \(X_j, Y_j\) ,我们可以通过同样的方法 \(a_j = \frac 1 N \sum\limits_{k=0}^{N-1}(X_k + iY_k)w_N^{-jk}\)来得到 \(X_k, Y_k\) 的 IDFT 。
3. 分裂
一般来说关于 \(FFT\) 的常数指的是 \(DFT/IDFT\) 的次数。
我们先用上面的技巧优化多项式乘法:
现在我们要将两个多项式 \(f(x) = \sum\limits_{j=0}^na_jx^j\) 和 \(g(x) = \sum\limits_{j=0}^mb_jx^j\) 卷起来得到 \(h(x) = \sum\limits_{j=0}^mc_jx^j\) 。正常的做法,我们需要 \(3\) 次两倍长度的 \(FFT\)。
我们先把它分奇偶拆成两个多项式,然后按照刚刚的合并方法,得到
记 \(X, Y, Z\) 为 \(f, g, h\) 的 \(DFT\) ,\(N\) 为单倍长度的 \(FFT\) 长度。
我们发现 \(Z_j\) 和 \(X_jY_j\) 只有一项不同,我们减一下:
(注意 \((18)\) 式的前半部分 \(k\) 不能取到 \(0\) ,所以我们换一下求和指标即可)。
然后又因为
继续化简得到
于是乎
这样做时空常数都是普通写法的一半。
4. 任意模数FFT
黑科技来了。
在任意模数下,系数有可能是 \(1e9\) 级别的,而double
的精度在 \(1e14\) 级别,于是直接 \(FFT\) 我们就炸精了。虽然我们有三模数 \(NTT\),但是那玩意,嗯,我们今天不讲。
本来想说那玩意慢死了,然后觉得 FFT 笑 NTT 也不过是五十步笑百步
但是需要大量卷积的题这个玩意就是快啊
我们将某个系数 \(a_j\) 表示成 \({a_{0j}}M+a_{1j}\),然后拆成两个系数小于 \(M\) 的多项式 \(A_0(x)\) 和 \(A_1(x)\) ,于是我们就能直接 FFT 了。一般 \(M\) 取 \(32768\) 比较方便。
这样做需要 \(8\) 次 \(FFT\)。比起 \(NTT\) 的 \(9\) 遍并没有快到哪儿去。(话说 \(NTT\) 也能这么拆吧?)
但是上面不是讲了优化技巧了吗?我们的虚部还是空的。
然后我们就能大力拆式子啦!
我们最后要求的式子是 \(A_1B_1 + M(A_1B_0+A_0B_1)+M^2A_0B_0\)。
我们令多项式 \(F(x), G(x)\) 的各项系数为
\(X, Y\) 表示 \(F(x), G(x)\) 的 \(DFT\),则
我们此时令 \(\mathbf {A, B, C, D}\) 来表示 \(A_0, A_1, B_0, B_1\) 的 \(DFT\),于是我们有
除 \(i\) 怎么办呢?复数除法超级麻烦的。但是我们注意到
于是变成乘法就好了。
然后接下来,我们再用类似的技巧, 构造两个新的 \(DFT\) ,\(\mathbf {P,Q}\)。
然后 \(IDFT\) 回去,最后的答案就是 \(\Re (\mathbf P_j) + M(\Im(\mathbf P_j)+\Re(\mathbf Q_j))+M^2\Im(\mathbf Q_j)\)。
然而实际应用的时候我们还能再化简一步。
同理
于是我们的常数更小了。
尽管如此,它还是跑不过普通的 \(NTT\),目前大概是我的 \(NTT\) 的时间的 \(1.5\) 倍吧,因为做了 \(4\) 次两倍长度的 \(FFT\) 。
总体来说常数因子的影响还是很明显的,最上面那个常数是正常写法的二分之一,时间也大概在二分之一左右,而下面的写法常数是三分之四,时间也是三分之四,总之 \(FFT\) 次数对代码的运行时间有着重大的影响。
注意这个算法对精度要求很高,很容易炸精度。
下面的代码里有一种处理单位根的方式,能保证每个数最多被乘 \(\log n\) 次,保证精度的时候常数也小。
#include <cstdio>
#include <cmath>
#include <cstring>
#include <cctype>
#include <algorithm>
const int LEN = 1 << 20 | 1;
char bufin[LEN];
char bufout[LEN];
char *Rd = bufin, *Wt = bufout;
#define getchar() (*Rd++)
#define putchar(x) (*Wt++ = x)
inline int read() {
register int ret, cc;
while (!isdigit(cc = getchar())){}ret = cc-48;
while ( isdigit(cc = getchar())) ret = cc-48+ret*10;
return ret;
}
inline void write(int x, char ch = '\n') {
register int stk[20], tp;
stk[tp = !x] = 0;
while (x) stk[++tp] = x % 10, x /= 10;
while (tp) putchar(stk[tp--] + '0');
putchar(ch);
}
struct Complex {
double x, y;
Complex(double x = 0, double y = 0)
:x(x), y(y) { }
inline Complex operator + (const Complex& rhs) const { return Complex(x + rhs.x, y + rhs.y); }
inline Complex operator - (const Complex& rhs) const { return Complex(x - rhs.x, y - rhs.y); }
inline Complex operator * (const Complex& rhs) const { return Complex(x * rhs.x - y * rhs.y, x * rhs.y + y * rhs.x); }
inline Complex operator - () const { return Complex(-x, -y); }
inline Complex operator ! () const { return Complex( x, -y); }
void print() {
printf("(%f, %f)\n", x, y);
}
};
const int MAXN = 131080;
const double PI = acos(-1.0);
inline void FFT(Complex*, int, int);
inline int getrev(int);
int rev[MAXN];
int N, M;
Complex A[MAXN];
Complex B[MAXN];
Complex C[MAXN];
Complex W[MAXN];
int tmp[MAXN];
int main() {
#ifndef ONLINE_JUDGE
freopen("test.in", "r", stdin);
#endif
fread(bufin, 1, LEN, stdin);
N = read() + 1, M = read() + 1;
for (int i = 0; i < N; ++i) (i & 1 ? A[i >> 1].y : A[i >> 1].x) = read();
for (int i = 0; i < M; ++i) (i & 1 ? B[i >> 1].y : B[i >> 1].x) = read();
int len = N + M - 1, bln = getrev((len+1)>>1);
FFT(A, bln, 1), FFT(B, bln, 1);
Complex w(1, 0);
for (int i = 0, j; i < bln; ++i, w = w * W[1]) {
j = (bln-i)&(bln-1);
C[i] = A[i]*B[i]-((A[i]-!A[j])*(B[i]-!B[j])*(w+1))*0.25;
}
FFT(C, bln, -1);
for (int i = 0; i < len; ++i) write(((i & 1) ? C[i >> 1].y : C[i >> 1].x) + 0.5, ' ');
putchar('\n');
fwrite(bufout, 1, Wt - bufout, stdout);
}
inline int getrev(int n) {
int bln = 1, bct = 0;
while (bln < n) bln <<= 1, bct++;
for (int i = 0; i < bln; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bct - 1));
return bln;
}
inline void FFT(Complex* a, int n, int opt) {
for (int i = 0; i < n; ++i)
if (i < rev[i]) std::swap(a[i], a[rev[i]]);
W[0] = 1;
for (int i = 1; i < n; i <<= 1) {
Complex wn(cos(PI / i), opt * sin(PI / i));
for (int k = i - 2; k >= 0; k -= 2)
W[k + 1] = (W[k] = W[k >> 1]) * wn;
for (int j = 0, p = (i << 1); j < n; j += p) {
Complex *x = a + j, *y = a + j + i, *w = W;
for (int k = 0; k < i; ++k, ++x, ++y, ++w) {
Complex tmp = *w * *y;
*y = *x - tmp, *x = *x + tmp;
}
}
}
if (opt == -1) for (int i = 0; i < n; ++i) a[i].x /= n, a[i].y /= n;
}
外加一个全家桶。
#include <bits/stdc++.h>
inline int read() {
int ret, cc, sign = 1;
while (!isdigit(cc = getchar()))
sign = cc == '-' ? -1 : sign;
ret = cc - 48;
while (isdigit(cc = getchar()))
ret = cc - 48 + ret * 10;
return ret * sign;
}
const int MOD = 998244353;
const int G = 3;
const int MAXN = 600010;
typedef std::vector<int> Poly;
typedef long long i64;
inline int add(int a, int b) { return (a += b) >= MOD ? a - MOD : a; }
inline int sub(int a, int b) { return (a -= b) < 0 ? a + MOD : a; }
inline int mul(int a, int b) { return 1ll * a * b % MOD; }
inline int qpow(int a, int p) {
int ret = 1;
for (p += (p < 0) * (MOD - 1); p; p >>= 1, a = mul(a, a))
if (p & 1) ret = mul(ret, a);
return ret;
}
inline Poly operator + (const Poly&, const Poly&);
inline Poly operator - (const Poly&, const Poly&);
inline Poly operator * (const Poly&, const Poly&);
inline Poly Inverse(const Poly&);
inline Poly Integrate(const Poly&);
inline Poly Derivate(const Poly&);
inline Poly Ln(const Poly&);
inline Poly Exp(const Poly&);
inline Poly Pow(const Poly&, int);
inline Poly Sqrt(const Poly&);
inline void Read(Poly&);
inline void Print(const Poly&);
int main() {
#ifdef ARK
freopen("test.in", "r", stdin);
#endif
Poly A(read()), One(1, 1);
int K = read();
Read(A);
Poly ans = Derivate(Pow(One + Ln(One +
Inverse(Exp(Integrate(Inverse(Sqrt(A)))))), K));
ans[ans.size() - 1] = 0;
Print(ans);
}
inline void Read(Poly& A) {
for (auto &x : A)
x = read();
}
inline void Print(const Poly& A) {
for (auto x : A)
printf("%d ", x);
puts("");
}
int rev[MAXN];
int W[MAXN];
inline int getrev(int n) {
int len = 1, cnt = 0;
while (len < n) len <<= 1, ++cnt;
for (int i = 0; i < len; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (cnt - 1));
return len;
}
struct Complex {
double x, y;
Complex(double x = 0, double y = 0) : x(x), y(y) {}
inline Complex operator+(const Complex& rhs) const { return Complex(x + rhs.x, y + rhs.y); }
inline Complex operator-(const Complex& rhs) const { return Complex(x - rhs.x, y - rhs.y); }
inline Complex operator*(const Complex& rhs) const {
return Complex(x * rhs.x - y * rhs.y, x * rhs.y + y * rhs.x);
}
};
Complex CW[MAXN];
const double Pi = acos(-1.0);
void FFT(Complex* a, int n, int opt) {
for (int i = 0; i < n; ++i)
if (i < rev[i])
std::swap(a[i], a[rev[i]]);
CW[0] = 1;
for (int i = 1; i < n; i <<= 1) {
Complex wn(cos(Pi / i), opt * sin(Pi / i));
for (int k = i - 2; k >= 0; k -= 2) CW[k + 1] = (CW[k] = CW[k >> 1]) * wn;
for (int j = 0, p = i << 1; j < n; j += p) {
for (int k = 0; k < i; ++k) {
Complex x = a[j + k], y = CW[k] * a[j + k + i];
a[j + k] = x + y, a[j + k + i] = x - y;
}
}
}
if (opt == -1)
for (int i = 0; i < n; ++i) a[i].x /= n, a[i].y /= n;
}
inline Poly Conv(const Poly& A, const Poly& B) {
static Complex a[MAXN], b[MAXN];
static Complex c[MAXN], d[MAXN];
int len = A.size() + B.size() - 1, bln = getrev(len);
for (size_t i = 0; i < A.size(); ++i) a[i] = Complex(A[i] & 32767, A[i] >> 15);
for (int i = A.size(); i < bln; ++i) a[i] = 0;
for (size_t i = 0; i < B.size(); ++i) b[i] = Complex(B[i] & 32767, B[i] >> 15);
for (int i = B.size(); i < bln; ++i) b[i] = 0;
FFT(a, bln, 1), FFT(b, bln, 1);
for (int i = 0, j; i < bln; ++i) {
j = (bln - i) & (bln - 1);
c[i] = Complex(a[i].x + a[j].x, a[i].y - a[j].y) * 0.5 * b[i];
d[i] = Complex(a[i].y + a[j].y, a[j].x - a[i].x) * 0.5 * b[i];
}
FFT(c, bln, -1), FFT(d, bln, -1);
Poly C(len);
for (int i = 0; i < len; ++i) {
i64 u = ((i64)(c[i].x + 0.5)) % MOD, v = ((i64)(c[i].y + 0.5)) % MOD;
i64 x = ((i64)(d[i].x + 0.5)) % MOD, y = ((i64)(d[i].y + 0.5)) % MOD;
C[i] = (u + ((v + x) << 15) % MOD + (y << 30) % MOD) % MOD;
}
return C;
}
inline void NTT(Poly& a, int n, int opt) {
a.resize(n);
for (int i = 0; i < n; ++i)
if (i < rev[i])
std::swap(a[i], a[rev[i]]);
for (int i = (W[0] = 1); i < n; i <<= 1) {
int wn = qpow(G, opt * (MOD - 1) / (i << 1));
for (int k = i - 2; k >= 0; k -= 2)
W[k + 1] = mul(W[k] = W[k >> 1], wn);
for (int j = 0, p = i << 1; j < n; j += p) {
for (int k = 0; k < i; ++k) {
int x = a[j + k], y = mul(W[k], a[j + k + i]);
a[j + k] = add(x, y), a[j + k + i] = sub(x, y);
}
}
}
if (opt == -1)
for (int i = 0, r = qpow(n, MOD - 2); i < n; ++i)
a[i] = mul(a[i], r);
}
inline Poly operator + (const Poly& lhs, const Poly& rhs) {
Poly ret(std::max(lhs.size(), rhs.size()));
for (size_t i = 0; i < ret.size(); ++i)
ret[i] = add(i < lhs.size() ? lhs[i] : 0, i < rhs.size() ? rhs[i] : 0);
return ret;
}
inline Poly operator - (const Poly& lhs, const Poly& rhs) {
Poly ret(std::max(lhs.size(), rhs.size()));
for (size_t i = 0; i < ret.size(); ++i)
ret[i] = sub(i < lhs.size() ? lhs[i] : 0, i < rhs.size() ? rhs[i] : 0);
return ret;
}
//inline Poly operator * (const Poly& lhs, const Poly& rhs) {
// Poly A(lhs), B(rhs);
// int len = A.size() + B.size() - 1;
// int bln = getrev(len);
// NTT(A, bln, 1), NTT(B, bln, 1);
// for (int i = 0; i < bln; ++i)
// A[i] = mul(A[i], B[i]);
// NTT(A, bln, -1), A.resize(len);
// return A;
//}
inline Poly operator * (const Poly& lhs, const Poly& rhs) {
return Conv(lhs, rhs);
}
inline Poly Inverse(const Poly& A) {
Poly B(1, qpow(A[0], MOD - 2));
int n = A.size() << 1;
for (int i = 2; i < n; i <<= 1) {
Poly C(A);
C.resize(i);
//int len = getrev(i << 1);
//NTT(C, len, 1), NTT(B, len, 1);
C = Poly(1, 2) - B * C;
C.resize(i);
B = B * C;
B.resize(i);
//for (int j = 0; j < len; ++j)
// B[j] = mul(B[j], sub(2, mul(B[j], C[j])));
//NTT(B, len, -1), B.resize(i);
}
B.resize(A.size());
return B;
}
inline std::vector<int> getinv() {
std::vector<int> ret(MAXN);
ret[1] = 1;
for (int i = 2; i < MAXN; ++i)
ret[i] = mul(MOD - MOD / i, ret[MOD % i]);
return ret;
}
std::vector<int> Inv = getinv();
inline Poly Integrate(const Poly& A) {
Poly C(A.size() + 1);
for (size_t i = 1; i < C.size(); ++i)
C[i] = mul(Inv[i], A[i - 1]);
return C;
}
inline Poly Derivate(const Poly& A) {
Poly C(A.size() - 1);
for (size_t i = 0; i < C.size(); ++i)
C[i] = mul(i + 1, A[i + 1]);
return C;
}
inline Poly Ln(const Poly& A) {
Poly C = Integrate(Derivate(A) * Inverse(A));
C.resize(A.size());
return C;
}
inline Poly Exp(const Poly& A) {
Poly B(1, 1);
int n = A.size() << 1;
for (int i = 2; i < n; i <<= 1) {
Poly C(A);
C.resize(i), B.resize(i);
B = B * (Poly(1, 1) - Ln(B) + C);
}
B.resize(A.size());
return B;
}
inline Poly Pow(const Poly& A, int k) {
Poly C(Ln(A));
for (auto &x : C)
x = mul(x, k);
return Exp(C);
}
inline Poly Sqrt(const Poly& A) {
Poly C(A);
int c = A[0], ic = qpow(c, MOD - 2);
for (auto &x : C)
x = mul(x, ic);
c = sqrt(c), C = Pow(C, Inv[2]);
for (auto &x : C)
x = mul(x, c);
return C;
}