O(nlog^2/loglogn)的cdq FFT


論文鴿在群里說了一下這個東西,我也實現了一下,發現效果還不錯。

由於這個 exp 的 \(O(n\log n)\) 算法非常的慢,所以我們一般采用 \(O(n\log^2 n)\) 的分治 FFT 來求解。

普通的分治 FFT 已經可以與論文鴿的 \(O(n\log n)\) exp 五五開了,但是有沒有更快的方法呢?

注意,這個優化只能在 cdq FFT 的時候采用,也就是說不能優化 n 個一次多項式的卷積之類的問題。

\(O(n\log^2n)\)

我們先來回憶一下普通的分治 FFT 是如何做的。

我們假設 \(F(x) = e^{G(X)}\),這里我們知道 \(G(x)\),我們要求解 \(F(x)\)

兩側求導,得 \(F^{'}(x) = e^{G(X)} \times G^{'}(X) = F(X) \times G^{'}(X)\)

也就是說我們是對這個式子進行求解:\(F^{'}(x) = F(X) \times G^{'}(X)\)

我們采用 \(solve(l, r)\) 表示求解 \(F(X)\) 的第 \(l\) 項到第 \(r\) 項。

取區間中點 \(mid\)

先調用 \(solve(l, mid)\) 來求解出前半部分。

再計算左側對右側的貢獻。

再調用 \(solve(mid + 1, r)\) 來求解出后半部分。

這就是普通的分治 FFT。

\(O(\frac{n\log^2n}{\log \log n})\)

首先分治 FFT 是一個樹狀結構,我們往往可以嘗試增加一層往下的分支數來優化深度。

我們設分支數為 \(B\)

如果直接分治 FFT,那么需要計算每個兒子對后面兒子的貢獻,每次計算需要一個長度為 \(O(n / B)\) 的卷積(\(n\)為目前分治區間長度)。

也就是說時間復雜度 \(T(n) = B \times T(\frac{n}{B}) + B \times n \times \log {\frac{n}{b}}\),大力求解得 \(B=2\) 時最優。

我是不是在玩你。

好的我們繼續。

我們真的是每一對兒子都要用一次卷積來計算貢獻嗎?

我們可以先求出這個兒子的點值,考慮它對后面兒子的貢獻,這個兒子的點值就不用重復計算了。

存儲前面兒子對它的貢獻的時候,你也可以直接存儲點值,最后一次 FFT 轉換即可,也不用多次計算了。

而卷上的是 \(G^{'}(x)\) 的一個區間,這個區間的點值也可以提前計算。

也就是說我們只需要 \(O(B)\) 次長度為 \(O(n / B)\) 的 FFT 了!

因此計算貢獻的部分復雜度變為了 \(O(B^2 \times \frac{n}{B} + B \times \frac{n}{b} \times \log {\frac{n}{b}})\)

\(O(Bn + n \times \log {\frac{n}{b}})\)

時間復雜度 \(T(n) = B \times T(\frac{n}{B}) + Bn + n \times \log {\frac{n}{b}}\)。不錯,求解一下。

發現 \(B = O(\log n)\) 的時候最優秀,時間復雜度為\(O(\frac{n\log^2n}{\log \log n})\)

事實上由於計算貢獻的時候非常的****,可以使用 avx2 進行優化,親測一定的常數優化之后進行 \(4 \times 10^6\) 的 exp 只需要 1.5s。

當然其他類似的 cdq FFT 也可以這樣進行優化,祝大家早日吊打 \(O(n\log n)\)

貼代碼:

#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2,tune=native")
#include<bits/stdc++.h>
#define rep(i, l, r) for(int i = (l), i##end = (r);i <= i##end;++i)
const int maxn = 1 << 19 | 1;
typedef long long ll;
typedef unsigned long long u64, ull;
const int mod = 998244353;
struct istream {
	static const int size = 1 << 21;
	char buf[size], *vin;
	inline istream() {
		fread(buf,1,size,stdin);
		vin = buf - 1;
	}
	inline istream& operator >> (int & x) {
		for(x = *++vin & 15;isdigit(*++vin);) x = x * 10 + (*vin & 15);
		return * this;
	}
} cin;
struct ostream {
	static const int size = 1 << 21;
	char buf[size], *vout;
	unsigned map[10000];
	inline ostream() {
		for(int i = 0;i < 10000;++i) {
			int p = i;
			map[i] = p % 10 + 48, p /= 10;
			map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
			map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
			map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
		}
		vout = buf + size;
	}
	inline ~ ostream()
	{ fwrite(vout,1,buf + size - vout,stdout); }
	inline ostream& operator << (int x) {
		for(;x > 10000;x /= 10000) *--(unsigned*&)vout = map[x % 10000];
		do *--vout = x % 10 + 48; while(x /= 10);
		return * this;
	}
	inline ostream& operator << (char x) {
		*--vout = x;
		return * this;
	}
} cout;
inline ll pow(ll a,int b,ll ans = 1){ for(;b;b >>= 1, a = a * a % mod) if(b & 1) ans = ans * a % mod; return ans; }
inline ll inverse(ll x){ return pow(x, mod - 2); }
int wn[1 << 13], rev[1 << 14], inv[maxn], lim, invlim;
inline void init_(int n) {
	int N = 1; for(;N < n;) N <<= 1;
	for(int i = 1;i < N;i <<= 1) {
		const int w = pow(3, mod / i / 2); wn[i] = 1;
		for(int j = 1;j < i;++j) wn[i + j] = (ll) wn[i + j - 1] * w % mod;
	}
	for(int i = 1;i <= N;i <<= 1) {
		for(int j = 1;j < i;++j) rev[i + j] = rev[i + (j >> 1)] >> 1 | j % 2 * i / 2;
	}
}
inline void init(int len) {
	lim = len; invlim = mod - (mod - 1) / lim;
}
inline void reduce(int & x) {
	x += x >> 31 & mod;
}
static u64 t[1 << 13];
inline void fft(int * a,int type) {
	for(int i = 0;i < lim;++i) t[i] = a[rev[i + lim]];
#define trans(i, j, k) \
	{ \
		const u64 x = wn[i + k] * t[i + j + k] % mod; \
		t[i + j + k] = t[j + k] + mod - x, t[j + k] += x; \
	}
	for(int i = 1;i < lim;i <<= 1) {
		if(i == 1) {
			for(int j = 0;j < lim;j += 8) {
				trans(1, j, 0);
				trans(1, j + 2, 0);
				trans(1, j + 4, 0);
				trans(1, j + 6, 0);
			}
		} else if(i == 2) {
			for(int j = 0;j < lim;j += 8) {
				trans(2, j, 0);
				trans(2, j, 1);
				trans(2, j + 4, 0);
				trans(2, j + 4, 1);
			}
		} else {
			for(int j = 0;j < lim;j += i + i) for(int k = 0;k < i;k += 4) {
				trans(i, j, k + 0);
				trans(i, j, k + 1);
				trans(i, j, k + 2);
				trans(i, j, k + 3);
			}
		}
	}
	if(type == 1) {
		for(int i = 0;i < lim;++i) a[i] = t[i] % mod;
	}
	if(type == 0) {
		a[0] = t[0] * invlim % mod;
		for(int i = 1;i < lim;++i) a[i] = t[lim - i] * invlim % mod;
	}
}
inline void fill(int * a, const int * b, int len) {
	memcpy(a, b, len << 2), memset(a + len, 0, lim - len << 2);
}
typedef std::function<int(int, int*)> fc;
struct solver {
	static const int C = 128;
	static const int B = 64;
	int n, N;
	int rem[maxn], g[maxn], * MM;

	int M[B][(maxn << 1) / B];
	u64 g0[maxn << 2];
	inline void Init(int len, int * multi) {
		MM = multi;
		for(n = len, N = 1;N < len;N <<= 1);
		for(int mid = (N + N) / B;mid > 1;mid /= B) {
			init(mid); 
			for(int j = 0;j + 1 < B;++j) {
				if(j * mid / 2 < n) {
					for(int i = 0;i < mid;++i) M[j][mid + i] = MM[i + j * mid / 2];
					fft(M[j] + mid, 1);
				}
			}
		}
	}
	inline void solve(int l, int r, u64 * g0, const fc & xxx) {
		if(r - l < C) {
			for(int i = l;i < r;++i) {
				int j = l;
				u64 x = rem[i];
#define T(o) (u64) g[j + o] * MM[i - j - o]
				for(;j + 15 < i;j += 16) {
					x = (x + T(0) + T(1) + T(2) + T(3) + T(4) + T(5) + T(6) + T(7) + 
						 	 T(8) + T(9) + T(10) + T(11) + T(12) + T(13) + T(14) + T(15)) % mod;
				}
				if(j + 7 < i) x += T(0) + T(1) + T(2) + T(3) + T(4) + T(5) + T(6) + T(7), j += 8;
				if(j + 3 < i) x += T(0) + T(1) + T(2) + T(3), j += 4;
				if(j + 1 < i) x += T(0) + T(1), j += 2;
				if(j < i) x += T(0);
#undef T
				rem[i] = x % mod;
				g[i] = xxx(i, rem + i);
			}
			return ;
		}
		const int DT = (r - l) / B;
		if(l) memset(g0, 0, r - l << 4);
		int end = 0;
		for(;end < B && l + end * DT < n;++end);
		for(int i = 0;i < end;++i) {
			int L = l + i * DT, R = L + DT;
			if(i) {
				static int T[maxn];
				init(DT + DT);
				for(int j = 0;j < lim;++j) T[j] = g0[2 * i * DT + j] % mod;
				fft(T, 2);
				for(int j = L;j < R;++j) rem[j] = (rem[j] + (ll) invlim * t[lim - j + L - DT]) % mod;
			}
			solve(L, R, g0 + (r - l << 1), xxx);
			if(i != end - 1) {
				init(DT + DT);
				static int b[maxn];
				fill(b, g + L, R - L), fft(b, 1);
				for(int j = i + 1;j < end;++j) {
					ull * g1 = g0 + lim * j;
					if(i == B / 2) {
						for(int k = 0;k < lim;++k) {
							g1[k] = (g1[k] + (ll) b[k] * M[j - i - 1][lim + k]) % mod;
						}
					} else {
						for(int k = 0;k < lim;++k) {
							g1[k] += (ll) b[k] * M[j - i - 1][lim + k];
						}
					}
				}
			}
		}
	}
	inline void solve(fc x) { solve(0, N, g0, x); }
};
int n, a[maxn], b[maxn];
int main() {
	static solver ln, exp;
	cin >> n;
	for(int i = 0;i < n;++i) {
		cin >> a[i]; if(i) a[i] = mod - a[i];
		b[i] = (ll) a[i] * i % mod;
	}
	inv[1] = 1;
	for(int i = 2;i < n;++i) {
		inv[i] = ll(mod - mod / i) * inv[mod % i] % mod;
	}
	init_((n + n) / solver::B + 1);
	ln.Init(n, a);
	ln.solve([](int pos, int * now) { return reduce(*now -= b[pos + 1]), *now; });
	for(int i = 1;i < n;++i) {
		b[i] = (ll) ln.g[i - 1] * inv[2] % mod;
	}
	exp.Init(n, b);
	exp.solve([](int pos, int * now) { return int(pos == 0 ? 1 : (ll) *now * inv[pos] % mod); });
	for(int i = n - 1;i >= 0;--i) {
		cout << ' ' << exp.g[i];
	}
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM