Description
給你 \(n\) 塊白木板,\(k\) 塊紅木板,分別有各自的長度 \(h_i\)。讓你用這些木板組成一段圍欄,要滿足:
- 只用一塊紅木板,且所有白木板的長度均嚴格小於紅木板長度;
- 紅木板左邊的白木板長度嚴格單調遞增;
- 紅木板右邊的白木板長度嚴格單調遞減
現在給出 \(q\) 組詢問,問周長為 \(x_i\) 的圍欄有多少種。
\(1\leq n,h_i,q\leq 3\cdot 10^5,4\leq x_i\leq 12\cdot 10^5,1\leq k\leq 5\)
Solution
如果我們選的紅木板長度為 \(L\),並且這段圍欄由 \(m\) 塊板子構成,顯然周長為 \(2\times (L+m)\)。
那么問題就轉化為了,求用 \(m-1\) 塊長度 \(<L\) 的白木板形成兩段長度嚴格單調遞增序列的方案數。
因為木板長度是離散的,我們可以考慮每種長度的木板的放法。
我們把所有長度 \(<L\) 的白木板取出。若某種長度的木板只有一塊。那么顯然,這塊木板可以放在紅木板的左邊或右邊(即任意一個序列中)。
對於某種長度有兩塊以上的時候,我們可以把他放在左邊、右邊或者兩邊都放。並且我們最多只會用 \(2\) 塊這樣的木板,所以多余的可以除去。
假設第一種情況(該長度的木板只有一塊)下的木板個數為 \(sa\)。顯然用這 \(sa\) 塊木板構成兩段序列總長度為 \(i\) 的方案數 \(a_i={sa \choose i}\times 2^i\)。
假設第二種情況下的木板個數為 \(sb\)(除去多余的木板)。用這 \(sb\) 塊木板構成兩段序列總長度為 \(i\) 的的方案數 \(b_i={sb \choose i}\)。
記兩種情況長度總和為 \(i\) 的方案數為 \(c_i\),那么容易發現我們要求的就是
\[ c_{m-1}=\sum_{i=0}^{m-1}a_ib_{m-1-i} \]
這是一個卷積式,那么我們設
\[ \begin{aligned} A(x)&=\sum_i a_i x^i\\ B(x)&=\sum_i b_i x^i\\ C(x)&=A(x)\otimes B(x)\\ &=\sum_i c_i x^i \end{aligned} \]
那么我們就可以用 \(\text{NTT}\) 來求出選該紅木板時,對應選不同白木板個數的方案數了。
因此我們可以枚舉每個紅木板,做一次 \(\text{NTT}\) 累計到答案中,\(O(1)\) 回答詢問。
總復雜度 \(O(k\times n\log n +q)\)。
Code
#include <bits/stdc++.h>
using namespace std;
const int N = 12e5+5, yzh = 998244353;
int n, k, q, cnt[N], x, ans[N], fac[N], ifac[N];
int A[N], B[N], a, b, L, R[N];
int quick_pow(int a, int b) {
int ans = 1;
while (b) {
if (b&1) ans = 1ll*ans*a%yzh;
b >>= 1, a = 1ll*a*a%yzh;
}
return ans;
}
int C(int n, int m) {return 1ll*fac[n]*ifac[m]%yzh*ifac[n-m]%yzh; }
void NTT(int *A, int o) {
for (int i = 0; i < n; i++) if (i < R[i]) swap(A[i], A[R[i]]);
for (int i = 1; i < n; i <<= 1) {
int gn = quick_pow(3, (yzh-1)/(i<<1)), x, y;
if (o == -1) gn = quick_pow(gn, yzh-2);
for (int j = 0; j < n; j += (i<<1)) {
int g = 1;
for (int k = 0; k < i; k++, g = 1ll*g*gn%yzh) {
x = A[j+k], y = 1ll*g*A[j+k+i]%yzh;
A[j+k] = (x+y)%yzh;
A[j+k+i] = (x-y)%yzh;
}
}
}
}
int main() {
scanf("%d%d", &n, &k);
fac[0] = ifac[0] = ifac[1] = 1;
for (int i = 2; i <= n; i++) ifac[i] = -1ll*yzh/i*ifac[yzh%i]%yzh;
for (int i = 1; i <= n; i++)
fac[i] = 1ll*i*fac[i-1]%yzh,
ifac[i] = 1ll*ifac[i-1]*ifac[i]%yzh;
for (int i = 1; i <= n; i++) scanf("%d", &x), cnt[x]++;
while (k--) {
scanf("%d", &x); a = b = 0;
for (int i = 1; i < x; i++)
if (cnt[i] >= 2) a += 2;
else if (cnt[i] == 1) b++;
memset(A, 0, sizeof(A));
memset(B, 0, sizeof(B));
for (int i = 0; i <= a; i++) A[i] = C(a, i);
for (int i = 0; i <= b; i++) B[i] = 1ll*C(b, i)*quick_pow(2, i)%yzh;
a += b; L = 0;
for (n = 1; n <= a; n <<= 1) ++L;
for (int i = 0; i < n; i++) R[i] = (R[i>>1]>>1)|((i&1)<<(L-1));
NTT(A, 1), NTT(B, 1);
for (int i = 0; i < n; i++) A[i] = 1ll*A[i]*B[i]%yzh;
NTT(A, -1);
int inv = quick_pow(n, yzh-2);
for (int i = 0; i <= a; i++) A[i] = 1ll*A[i]*inv%yzh;
for (int i = 0; i <= a; i++)
(ans[(x+1+i)<<1] += A[i]) %= yzh;
}
scanf("%d", &q);
while (q--) scanf("%d", &x), printf("%d\n", (ans[x]+yzh)%yzh);
return 0;
}