題目傳送門:LOJ #3045。
題意簡述
略。
題解
從高斯消元出發好像需要一些集合冪級數的知識,就不從這個角度思考了。
令 \(\displaystyle \dot p = \sum_{i = 1}^{n} p_i\)。
我們考慮一個操作序列 \(\{a_1, a_2, \ldots , a_k\}\),其中 \(1 \le a_j \le n\),就表示第 \(i\) 次按下了開關 \(a_j\)。
那么按 \(k\) 次后恰好得到這個序列的概率就是 \(\displaystyle \prod_{j = 1}^{k} (p_{a_j} / \dot p)\)。
那么我們考慮如果按下這個序列后恰好得到了目標狀態 \(s\):
當且僅當對於每個 \(i\)(\(1 \le i \le n\))均滿足按下開關 \(i\) 的次數的奇偶性恰好等於 \(s_i\)。
形式化地說,就是對於每個 \(i\) 有 \(\displaystyle \left( \sum_{j = 1}^{k} [a_j = i] \right) \bmod 2 = s_i\)。
那么我們對每個 \(i\) 分開考慮,對於 \(s_i = 0\) 的需要按偶數次,對於 \(s_i = 1\) 的需要按奇數次。
- 對於某個 \(s_i = 0\) 的 \(i\),我們給出這樣的數列:\(f_i = \{1, 0, {(p_i / \dot p)}^2, 0, {(p_i / \dot p)}^4, 0, {(p_i / \dot p)}^6, 0, \ldots \}\)。
- 對於某個 \(s_i = 1\) 的 \(i\),我們給出這樣的數列:\(f_i = \{0, p_i / \dot p, 0, {(p_i / \dot p)}^3, 0, {(p_i / \dot p)}^5, 0, {(p_i / \dot p)}^7, \ldots \}\)。
可以發現,把所有的 \(i\) 的數列全部二項卷積起來,就得到了一個新的數列 \(f\),這個數列滿足:
對於 \(f\) 的第 \(k\) 項 \(f_k\),就表示了當按 \(k\) 下開關時,恰好得到狀態 \(s\) 的概率。
因為是 二項卷積,所以我們把這個過程寫成 指數型概率生成函數 的形式:
定義 \(\hat F_i (x) = \mathbf{EGF} \left( { \left\{ [j \bmod 2 = s_i] {(p_i / \dot p)}^j \right\} }_{j = 0}^{\infty} \right)\),
也就是每個 \(i\) 對應的上述數列 \(f_i\) 的指數型生成函數,
寫做封閉形式,就是 \(\displaystyle \hat F_i (x) = \frac{e^{(p_i / \dot p) x} + {(-1)}^{s_i} e^{-(p_i / \dot p) x}}{2}\)。
所以最終得到的 \(f\) 的 EGF 就是 \(\displaystyle \hat F (x) = \prod_{i = 1}^{n} \frac{e^{(p_i / \dot p) x} + {(-1)}^{s_i} e^{-(p_i / \dot p) x}}{2}\)。
看起來非常的變態,但是還沒完!出什么問題了?
首先我們要明確:得到 \(f\) 能干啥?
發現 \(f\) 的性質是:\(f_k\) 表示按恰好 \(k\) 次開關得到狀態 \(s\) 的概率,那么根據期望的定義,答案就是 \(\displaystyle \sum_{i = 0}^{\infty} i f_i\)。
這是啥啊,就是 \(f\) 對應的 普通生成函數 \(\displaystyle F(x) = \sum_{i = 0}^{\infty} f_i x^i\),它在 \(1\) 處的導數,也就是 \(F'(1)\)。
(回顧形式冪級數求導,以及求值的定義)
但是 錯了,再觀察一下,題目要求的是 第一次 到達狀態 \(s\) 的期望步數,而不是現在這個樣子。
(因為可能不是第一次,而是此前已經經過很多次了。實際上如果直接求這個,甚至是不收斂的)
那么怎么辦呢?我們發現需要排除第一次到達 \(s\) 后,又經過若干步返回 \(s\) 的情況,也就是返回原狀態了。
由此,我們考慮求出數列 \(g\),其中 \(g_k\) 表示在 \(k\) 步后恰好返回原狀態的概率。
那么可以發現,如果令最終答案的數列為 \(h\),有 \(h \ast g = f\)(\(h\) 卷 \(g\) 等於 \(f\),是普通卷積不是二項卷積)。
而 \(g\) 應該如何求得呢?其實就是當全部 \(s_i = 0\) 時的 \(f\) 啦,因為是要返回原狀態嘛。
上面說了一堆理論上的東西,現在我們考慮如何實現。
首先發現求的時候是 EGF,但是算答案的時候是 OGF,這很怪。我們觀察一下形式看看能不能轉換。
對於 \(\hat F\),有形式 \(\displaystyle \hat F (x) = \prod_{i = 1}^{n} \frac{e^{(p_i / \dot p) x} + {(-1)}^{s_i} e^{-(p_i / \dot p) x}}{2}\)。
我們把每個形如 \(a_w e^{(w / \dot p)x}\) 的式子看作一項,可以發現最終 \(w\) 的取值在 \([-\dot p, \dot p]\)。
所以把 \(\hat F (x)\) 表示成 \(\displaystyle \sum_{w = -\dot p}^{\dot p} a_w e^{(w / \dot p) x}\) 的形式后,我們就有 \(\displaystyle f_k = \sum_{w = -\dot p}^{\dot p} a_w {(w / \dot p)}^k\)。
再把這個形式轉換成 OGF,得到 \(\displaystyle \mathbf{OGF} (f) = F(x) = \sum_{w = -\dot p}^{\dot p} \frac{a_w}{1 - (w / \dot p) x}\)。
這時候考慮求出每個 \(a_w\),可以發現做一個背包就行了,復雜度為 \(\mathcal O (n \dot p)\)。
(觀察背包轉移時的系數都是 \(\pm 1 / 2\),可以使用多項式 Exp 優化到 \(\mathcal O (n + \dot p \log \dot p)\),但是沒有必要)
對於 \(\displaystyle \mathbf{OGF} (g) = G(x) = \sum_{w = -\dot p}^{\dot p} \frac{b_w}{1 - (w / \dot p) x}\) 同理,我們需要求出每一個 \(b_w\)。
求出所有 \(a_w, b_w\) 之后,我們就掌握了 \(f, g\) 的一些性質,然后對於答案 \(h\),令其普通生成函數為 \(H\)。
則根據上面的解釋,有 \(H = F / G\),並且最終我們需要求出 \(H'(1)\)。
因為這里 \(F, G, H\) 都可能有無限項,所以要考慮通過 \(a_w, b_w\) 去求出答案。
考慮除法求導法則:\(\displaystyle H' = {(F / G)}' = \frac{F'G - G'F}{G^2}\)。
所以只要求出 \(F(1), G(1), F'(1), G'(1)\) 即可。
然而很可惜,我們發現因為存在 \(\displaystyle \frac{a_{\dot p}}{1 - (\dot p / \dot p) x}\) 這一項,所以 \(F, G, F', G'\) 在 \(x = 1\) 處不收斂。
我們知道答案一定收斂,所以考慮洛都可以洛做一點變換:把 \(F\) 和 \(G\) 都乘上 \((1 - x)\)。那么就有:
- \(F(1) = a_{\dot p}\)。
- \(\displaystyle F'(1) = \sum_{w = -\dot p}^{\dot p - 1} \frac{a_w}{w / \dot p - 1}\)。
- \(G(1) = b_{\dot p}\)。
- \(\displaystyle G'(1) = \sum_{w = -\dot p}^{\dot p - 1} \frac{b_w}{w / \dot p - 1}\)。
具體計算過程省略,就是按照求導的公式算而已。所以:
求出所有的 \(a_w, b_w\) 后按照此式計算即可,時間復雜度為 \(\mathcal O (n \dot p + \dot p \log mod)\),代碼如下:
#include <cstdio>
#include <algorithm>
typedef long long LL;
const int Mod = 998244353, Inv2 = (Mod + 1) / 2;
const int MN = 105, MP = 50005;
inline void Add(int &x, LL y) { x = (x + y) % Mod; }
inline int qPow(int b, int e) {
int a = 1;
for (; e; e >>= 1, b = (LL)b * b % Mod)
if (e & 1) a = (LL)a * b % Mod;
return a;
}
inline int gInv(int b) { return qPow(b, Mod - 2); }
int N, s[MN], p[MN], sump;
int _a[2][MP * 2], _b[2][MP * 2], *a[2] = {_a[0] + MP, _a[1] + MP}, *b[2] = {_b[0] + MP, _b[1] + MP};
int Ans;
int main() {
scanf("%d", &N);
for (int i = 1; i <= N; ++i) scanf("%d", &s[i]), s[i] = s[i] ? -1 : 1;
b[0][0] = a[0][0] = 1;
for (int i = 1; i <= N; ++i) {
scanf("%d", &p[i]);
for (int j = -sump - p[i]; j <= sump + p[i]; ++j) b[1][j] = a[1][j] = 0;
for (int j = -sump; j <= sump; ++j)
Add(a[1][j + p[i]], (LL)Inv2 * a[0][j]),
Add(a[1][j - p[i]], s[i] * (LL)Inv2 * a[0][j]),
Add(b[1][j + p[i]], (LL)Inv2 * b[0][j]),
Add(b[1][j - p[i]], (LL)Inv2 * b[0][j]);
sump += p[i];
std::swap(a[0], a[1]), std::swap(b[0], b[1]);
}
int isump = gInv(sump), *A = a[0], *B = b[0];
for (int j = -sump; j < sump; ++j)
Add(Ans, ((LL)A[j] * B[sump] - (LL)B[j] * A[sump]) % Mod * gInv((LL)j * isump % Mod - 1));
Ans = (LL)Ans * qPow(B[sump], Mod - 3) % Mod;
printf("%d\n", (Ans + Mod) % Mod);
return 0;
}