論文鴿在群里說了一下這個東西,我也實現了一下,發現效果還不錯。
由於這個 exp 的 \(O(n\log n)\) 算法非常的慢,所以我們一般采用 \(O(n\log^2 n)\) 的分治 FFT 來求解。
普通的分治 FFT 已經可以與論文鴿的 \(O(n\log n)\) exp 五五開了,但是有沒有更快的方法呢?
注意,這個優化只能在 cdq FFT 的時候采用,也就是說不能優化 n 個一次多項式的卷積之類的問題。
\(O(n\log^2n)\)
我們先來回憶一下普通的分治 FFT 是如何做的。
我們假設 \(F(x) = e^{G(X)}\),這里我們知道 \(G(x)\),我們要求解 \(F(x)\)。
兩側求導,得 \(F^{'}(x) = e^{G(X)} \times G^{'}(X) = F(X) \times G^{'}(X)\)
也就是說我們是對這個式子進行求解:\(F^{'}(x) = F(X) \times G^{'}(X)\)
我們采用 \(solve(l, r)\) 表示求解 \(F(X)\) 的第 \(l\) 項到第 \(r\) 項。
取區間中點 \(mid\)。
先調用 \(solve(l, mid)\) 來求解出前半部分。
再計算左側對右側的貢獻。
再調用 \(solve(mid + 1, r)\) 來求解出后半部分。
這就是普通的分治 FFT。
\(O(\frac{n\log^2n}{\log \log n})\)
首先分治 FFT 是一個樹狀結構,我們往往可以嘗試增加一層往下的分支數來優化深度。
我們設分支數為 \(B\)。
如果直接分治 FFT,那么需要計算每個兒子對后面兒子的貢獻,每次計算需要一個長度為 \(O(n / B)\) 的卷積(\(n\)為目前分治區間長度)。
也就是說時間復雜度 \(T(n) = B \times T(\frac{n}{B}) + B \times n \times \log {\frac{n}{b}}\),大力求解得 \(B=2\) 時最優。
我是不是在玩你。
好的我們繼續。
我們真的是每一對兒子都要用一次卷積來計算貢獻嗎?
我們可以先求出這個兒子的點值,考慮它對后面兒子的貢獻,這個兒子的點值就不用重復計算了。
存儲前面兒子對它的貢獻的時候,你也可以直接存儲點值,最后一次 FFT 轉換即可,也不用多次計算了。
而卷上的是 \(G^{'}(x)\) 的一個區間,這個區間的點值也可以提前計算。
也就是說我們只需要 \(O(B)\) 次長度為 \(O(n / B)\) 的 FFT 了!
因此計算貢獻的部分復雜度變為了 \(O(B^2 \times \frac{n}{B} + B \times \frac{n}{b} \times \log {\frac{n}{b}})\),
即 \(O(Bn + n \times \log {\frac{n}{b}})\)。
時間復雜度 \(T(n) = B \times T(\frac{n}{B}) + Bn + n \times \log {\frac{n}{b}}\)。不錯,求解一下。
發現 \(B = O(\log n)\) 的時候最優秀,時間復雜度為\(O(\frac{n\log^2n}{\log \log n})\)。
事實上由於計算貢獻的時候非常的****,可以使用 avx2 進行優化,親測一定的常數優化之后進行 \(4 \times 10^6\) 的 exp 只需要 1.5s。
當然其他類似的 cdq FFT 也可以這樣進行優化,祝大家早日吊打 \(O(n\log n)\)。
貼代碼:
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2,tune=native")
#include<bits/stdc++.h>
#define rep(i, l, r) for(int i = (l), i##end = (r);i <= i##end;++i)
const int maxn = 1 << 19 | 1;
typedef long long ll;
typedef unsigned long long u64, ull;
const int mod = 998244353;
struct istream {
static const int size = 1 << 21;
char buf[size], *vin;
inline istream() {
fread(buf,1,size,stdin);
vin = buf - 1;
}
inline istream& operator >> (int & x) {
for(x = *++vin & 15;isdigit(*++vin);) x = x * 10 + (*vin & 15);
return * this;
}
} cin;
struct ostream {
static const int size = 1 << 21;
char buf[size], *vout;
unsigned map[10000];
inline ostream() {
for(int i = 0;i < 10000;++i) {
int p = i;
map[i] = p % 10 + 48, p /= 10;
map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
}
vout = buf + size;
}
inline ~ ostream()
{ fwrite(vout,1,buf + size - vout,stdout); }
inline ostream& operator << (int x) {
for(;x > 10000;x /= 10000) *--(unsigned*&)vout = map[x % 10000];
do *--vout = x % 10 + 48; while(x /= 10);
return * this;
}
inline ostream& operator << (char x) {
*--vout = x;
return * this;
}
} cout;
inline ll pow(ll a,int b,ll ans = 1){ for(;b;b >>= 1, a = a * a % mod) if(b & 1) ans = ans * a % mod; return ans; }
inline ll inverse(ll x){ return pow(x, mod - 2); }
int wn[1 << 13], rev[1 << 14], inv[maxn], lim, invlim;
inline void init_(int n) {
int N = 1; for(;N < n;) N <<= 1;
for(int i = 1;i < N;i <<= 1) {
const int w = pow(3, mod / i / 2); wn[i] = 1;
for(int j = 1;j < i;++j) wn[i + j] = (ll) wn[i + j - 1] * w % mod;
}
for(int i = 1;i <= N;i <<= 1) {
for(int j = 1;j < i;++j) rev[i + j] = rev[i + (j >> 1)] >> 1 | j % 2 * i / 2;
}
}
inline void init(int len) {
lim = len; invlim = mod - (mod - 1) / lim;
}
inline void reduce(int & x) {
x += x >> 31 & mod;
}
static u64 t[1 << 13];
inline void fft(int * a,int type) {
for(int i = 0;i < lim;++i) t[i] = a[rev[i + lim]];
#define trans(i, j, k) \
{ \
const u64 x = wn[i + k] * t[i + j + k] % mod; \
t[i + j + k] = t[j + k] + mod - x, t[j + k] += x; \
}
for(int i = 1;i < lim;i <<= 1) {
if(i == 1) {
for(int j = 0;j < lim;j += 8) {
trans(1, j, 0);
trans(1, j + 2, 0);
trans(1, j + 4, 0);
trans(1, j + 6, 0);
}
} else if(i == 2) {
for(int j = 0;j < lim;j += 8) {
trans(2, j, 0);
trans(2, j, 1);
trans(2, j + 4, 0);
trans(2, j + 4, 1);
}
} else {
for(int j = 0;j < lim;j += i + i) for(int k = 0;k < i;k += 4) {
trans(i, j, k + 0);
trans(i, j, k + 1);
trans(i, j, k + 2);
trans(i, j, k + 3);
}
}
}
if(type == 1) {
for(int i = 0;i < lim;++i) a[i] = t[i] % mod;
}
if(type == 0) {
a[0] = t[0] * invlim % mod;
for(int i = 1;i < lim;++i) a[i] = t[lim - i] * invlim % mod;
}
}
inline void fill(int * a, const int * b, int len) {
memcpy(a, b, len << 2), memset(a + len, 0, lim - len << 2);
}
typedef std::function<int(int, int*)> fc;
struct solver {
static const int C = 128;
static const int B = 64;
int n, N;
int rem[maxn], g[maxn], * MM;
int M[B][(maxn << 1) / B];
u64 g0[maxn << 2];
inline void Init(int len, int * multi) {
MM = multi;
for(n = len, N = 1;N < len;N <<= 1);
for(int mid = (N + N) / B;mid > 1;mid /= B) {
init(mid);
for(int j = 0;j + 1 < B;++j) {
if(j * mid / 2 < n) {
for(int i = 0;i < mid;++i) M[j][mid + i] = MM[i + j * mid / 2];
fft(M[j] + mid, 1);
}
}
}
}
inline void solve(int l, int r, u64 * g0, const fc & xxx) {
if(r - l < C) {
for(int i = l;i < r;++i) {
int j = l;
u64 x = rem[i];
#define T(o) (u64) g[j + o] * MM[i - j - o]
for(;j + 15 < i;j += 16) {
x = (x + T(0) + T(1) + T(2) + T(3) + T(4) + T(5) + T(6) + T(7) +
T(8) + T(9) + T(10) + T(11) + T(12) + T(13) + T(14) + T(15)) % mod;
}
if(j + 7 < i) x += T(0) + T(1) + T(2) + T(3) + T(4) + T(5) + T(6) + T(7), j += 8;
if(j + 3 < i) x += T(0) + T(1) + T(2) + T(3), j += 4;
if(j + 1 < i) x += T(0) + T(1), j += 2;
if(j < i) x += T(0);
#undef T
rem[i] = x % mod;
g[i] = xxx(i, rem + i);
}
return ;
}
const int DT = (r - l) / B;
if(l) memset(g0, 0, r - l << 4);
int end = 0;
for(;end < B && l + end * DT < n;++end);
for(int i = 0;i < end;++i) {
int L = l + i * DT, R = L + DT;
if(i) {
static int T[maxn];
init(DT + DT);
for(int j = 0;j < lim;++j) T[j] = g0[2 * i * DT + j] % mod;
fft(T, 2);
for(int j = L;j < R;++j) rem[j] = (rem[j] + (ll) invlim * t[lim - j + L - DT]) % mod;
}
solve(L, R, g0 + (r - l << 1), xxx);
if(i != end - 1) {
init(DT + DT);
static int b[maxn];
fill(b, g + L, R - L), fft(b, 1);
for(int j = i + 1;j < end;++j) {
ull * g1 = g0 + lim * j;
if(i == B / 2) {
for(int k = 0;k < lim;++k) {
g1[k] = (g1[k] + (ll) b[k] * M[j - i - 1][lim + k]) % mod;
}
} else {
for(int k = 0;k < lim;++k) {
g1[k] += (ll) b[k] * M[j - i - 1][lim + k];
}
}
}
}
}
}
inline void solve(fc x) { solve(0, N, g0, x); }
};
int n, a[maxn], b[maxn];
int main() {
static solver ln, exp;
cin >> n;
for(int i = 0;i < n;++i) {
cin >> a[i]; if(i) a[i] = mod - a[i];
b[i] = (ll) a[i] * i % mod;
}
inv[1] = 1;
for(int i = 2;i < n;++i) {
inv[i] = ll(mod - mod / i) * inv[mod % i] % mod;
}
init_((n + n) / solver::B + 1);
ln.Init(n, a);
ln.solve([](int pos, int * now) { return reduce(*now -= b[pos + 1]), *now; });
for(int i = 1;i < n;++i) {
b[i] = (ll) ln.g[i - 1] * inv[2] % mod;
}
exp.Init(n, b);
exp.solve([](int pos, int * now) { return int(pos == 0 ? 1 : (ll) *now * inv[pos] % mod); });
for(int i = n - 1;i >= 0;--i) {
cout << ' ' << exp.g[i];
}
}