UOJ#449. 【集訓隊作業2018】喂鴿子(期望dp)


題意

\(n\) 只鴿子,每只鴿子需要 \(k\) 粒玉米才能喂飽。問每次隨意喂給 \(n\) 個鴿子中的一個,期望多久所有鴿子都被喂飽。

對於 \(998244353\) 取模。

數據范圍

\(n \le 50, k \le 1000\)

題解

\(\mathcal O(n^2k \log k)\)

題目問的是最晚喂飽的鴿子,我們用 \(\min - \max\) 反演變成對於每個集合問最早被喂飽的鴿子。

不難發現只有集合大小是有用的,我們等價於算:

\[ans = \sum_{c = 1}^{n} (-1)^{c + 1} {n \choose c} g_c \]

我們只需要算 \(g_c\) 了,即大小為 \(c\) 的集合中最早被喂飽鴿子的期望時間。

我們考慮把期望轉成概率,即

\[g_c = \sum_{s \ge 1} P(x \ge s) \]

我們相當於要算到 \(s - 1​\) 時刻,\(c​\) 只鴿子都沒有被喂飽的概率。

我們為了算這個,輔助設 \(f_{c, s}​\)\(c​\) 只鴿子,喂了 \(s​\) 只玉米還沒有被喂飽的概率。

那么就有

\[g_c = \sum_{i \ge 1} \sum_{s = 1}^{i - 1} {i - 1 \choose s} f_{c, s} (\frac{n - c}{n})^{i - 1} \]

對於這種式子,我們通常需要交換和式,為了方便令 \(\displaystyle p = \frac{n - c}n\) 那么有

\[g_c = \sum_{s = 1}^{c(k - 1)} f_{c, s} \sum_{t \ge 1} {s + t \choose s} p^{i - 1} \]

對於后面的式子,我們不難想到一個經典的生成函數形式,即

\[(\frac 1{1 - x})^a = \sum_{i \ge 0} {a + i - 1 \choose a - 1} x^i \]

證明,考慮隔板法或者二項式展開。

那么其實就是

\[(\frac{1}{1 - p})^{s + 1} = (\frac{n}{c})^{1 + c}\\ \]

下面我們考慮如何計算 \(f_{c, s}\) ,我們枚舉第 \(c\) 個鴿子喂了幾顆玉米,那么就有

\[f_{c, s} = \sum_{i = 0}^{\min(s, k - 1)} {s \choose i} f_{c - 1, s - i} \frac{1}{n^i} \]

直接做是 \(\mathcal O(n^2k^2)\) 的,用 \(NTT\) 優化就可以做到 \(\mathcal O(n^2k \log k)\) 啦。

\(\mathcal O(n^2k)\)

其實有個更高妙的做法。

稱一粒玉米是有效玉米,當且僅當它被投喂給了一只沒有飽的鴿子。那么有效玉米序列的長度是固定的 \(n k\) 。現在考慮枚舉所有的有效玉米序列,計算對答案的貢獻。下面記 \(r_i\) 表示投喂第 \(i\) 粒有效玉米前已經有多少鴿子飽了。

那么對於一個玉米序列的貢獻其實就是

\[(\prod_{i = 1}^{nk} P_{r_i}) (\sum_{i = 1}^{nk}E_{r_i}) \]

其中 \(\displaystyle P_x = \frac 1 {n - x}, E_x = \frac n{n - x}\) 前面表示的這個序列的概率(注意每個鴿子是不同的),后一項表示相鄰兩個有效玉米之間需要投遞個數的期望。

直接 \(dp\) 似乎不太方便。因為無法確定下一粒玉米投喂后是否會是一只鴿子吃飽。注意到貢獻只和 \(r_i\) 有關,而一只鴿子吃飽前是不會對 \(r_i\) 產生影響的。所以可以認為一只鴿子吃飽前其有效玉米都是 **“白玉米” **。我們只在一只鴿子吃飽的時侯把白玉米染色。

這樣就可以 \(dp\) 了,先強制鴿子吃飽的順序是 \(1\)\(n\) ,最后乘 \(n!\) 。設 \(f_{m, c}\) 表示投喂了 \(m\) 粒有效玉米,前 \(c\) 只鴿子已經飽了的貢獻之和。\(g_{m, c}\) 表示概率之和。

推一下式子,那么有

\[\sum \prod_{i \le m} P_{r_i} P_{r_{m + 1}} (\sum_{i \le m} E_{r_i} + E_{r_m + 1})\\ = P_{r_{m + 1}}(\sum \prod_{i \le m} P_{r_i} \sum_{i \le m} E_{r_i}) + P_{r_{m + 1}} E_{r_{m + 1}} (\sum \prod_{i \le m} P_{r_i})\\ = P_{r_{m + 1}}f_{m, c} + P_{r_{m + 1}}E_{r_{m + 1}} g_{m, c} \]

顯然 \(r_{m+1} = c\) 。而新的概率之和只要簡單地乘個 \(P_{r_{m+1}}\) 就行了。

接下來有兩種轉移,第一種是加入一粒白玉米,這種直接做。另一種是在 \(m + 1\) 處有一只鴿子吃飽了,這種轉移需要乘上 \(\displaystyle {m−ck \choose k - 1}\) 表示給白玉米染色的方案數。最后有一只鴿子吃飽了 \(f_{nk, n} · n!\) 就是答案。

代碼

\(\mathcal O(n^2k \log k)\)

#include <bits/stdc++.h>

#define For(i, l, r) for (register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for (register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Rep(i, r) for (register int i = (0), i##end = (int)(r); i < i##end; ++i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << (x) << endl

using namespace std;

template<typename T> inline bool chkmin(T &a, T b) { return b < a ? a = b, 1 : 0; }
template<typename T> inline bool chkmax(T &a, T b) { return b > a ? a = b, 1 : 0; }

inline int read() {
	int x(0), sgn(1); char ch(getchar());
	for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
	for (; isdigit(ch); ch = getchar()) x = (x * 10) + (ch ^ 48);
	return x * sgn;
}

void File() {
#ifdef zjp_shadow
	freopen ("449.in", "r", stdin);
	freopen ("449.out", "w", stdout);
#endif
}

const int N = 55, K = 1010;

namespace Computation {
	const int Mod = 998244353, g = 3;

	inline int fpm(int x, int power) {
		int res = 1;
		for (; power; power >>= 1, x = 1ll * x * x % Mod)
			if (power & 1) res = 1ll * res * x % Mod;
		return res;
	}

	inline void add(int &x, int y) { if ((x += y) >= Mod) x -= Mod; }
	inline void sub(int &x, int y) { if ((x -= y) < 0) x += Mod; }
	inline int mul(int x, int y) { return 1ll * x * y % Mod; }
#define div Div
	inline int div(int x, int y) { return 1ll * x * fpm(y, Mod - 2) % Mod; }

	int fac[N * K], ifac[N * K];
	void Fac_Init(int maxn) {
		fac[0] = ifac[0] = 1;
		For (i, 1, maxn) fac[i] = mul(fac[i - 1], i);
		ifac[maxn] = fpm(fac[maxn], Mod - 2);
		Fordown (i, maxn - 1, 1) ifac[i] = mul(ifac[i + 1], i + 1);
	}
	inline int comb(int n, int m) {
		if (n < 0 || m < 0 || n < m) return 0;
		return mul(mul(fac[n], ifac[n - m]), ifac[m]);
	}
}

namespace Poly {

	using namespace Computation;

	const int Maxn = 1 << 20;

	int powg[Maxn], invpowg[Maxn];

	void NTT_Init() {
		for (int i = 2; i < Maxn; i <<= 1)
			invpowg[i] = fpm(powg[i] = fpm(g, (Mod - 1) / i), Mod - 2);
	}

	int len, rev[Maxn];

	void NTT(int *P, int opt) {
		Rep (i, len) if (i < rev[i]) swap(P[i], P[rev[i]]);
		for (int i = 2, p = 1; i <= len; p = i, i <<= 1) {
			int Wi = opt == 1 ? powg[i] : invpowg[i];
			for (int j = 0; j < len; j += i)
				for (int k = 0, x = 1; k < p; ++ k) {
					int u = P[j + k], v = mul(x, P[j + k + p]);
					P[j + k] = (u + v) % Mod; 
					P[j + k + p] = (u - v + Mod) % Mod; 
					x = mul(x, Wi);
				}
		}
		if (!~opt) {
			int inv = fpm(len, Mod - 2);
			Rep (i, len) P[i] = mul(P[i], inv);
		}
	}

	int A[Maxn], B[Maxn], C[Maxn];
	void Mult(int *a, int *b, int *c, int na, int nb) {
		int nc = na + nb, bit = 0;
		for (len = 1; len <= nc; len <<= 1) ++ bit;
		Rep (i, len) A[i] = B[i] = 0;

		For (i, 0, na) A[i] = a[i];
		For (i, 0, nb) B[i] = b[i];

		Rep (i, len) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
		NTT(A, 1); NTT(B, 1);
		Rep (i, len) C[i] = mul(A[i], B[i]);
		NTT(C, -1);
		For (i, 0, nc) c[i] = C[i];
	}

}

using namespace Computation;

int n, k, f[N][N * K * 2], invn[N * K];

int a[N * K * 2], b[N * K * 2];

int main () {

	File();

	n = read(); k = read();

	Fac_Init(n * k);

	invn[0] = 1; invn[1] = fpm(n, Mod - 2);
	For (i, 2, k) invn[i] = mul(invn[i - 1], invn[1]);

	int ans = 0;

	Poly :: NTT_Init();

	f[0][0] = 1;
	For (c, 1, n) {
		For (s, 0, (c - 1) * (k - 1)) 
			a[s] = mul(f[c - 1][s], ifac[s]);
		For (i, 0, k - 1)
			b[i] = mul(invn[i], ifac[i]);

		Poly :: Mult(a, b, f[c], (c - 1) * (k - 1), k - 1);
		For (s, 0, c * (k - 1))
			f[c][s] = mul(f[c][s], fac[s]);
	}

	For (c, 1, n) {
		int res = 0, base = div(n, c), coef = base;
		For (s, 0, c * (k - 1)) 
			add(res, mul(f[c][s], coef)), coef = mul(coef, base);
		add(ans, mul(comb(n, c), mul(c & 1 ? 1 : Mod - 1, res)));
	}
	printf ("%d\n", ans);

	return 0;

}

\(\mathcal O(n^2k)\)

#include <bits/stdc++.h>

#define For(i, l, r) for (register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for (register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Rep(i, r) for (register int i = (0), i##end = (int)(r); i < i##end; ++i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << (x) << endl

using namespace std;

template<typename T> inline bool chkmin(T &a, T b) { return b < a ? a = b, 1 : 0; }
template<typename T> inline bool chkmax(T &a, T b) { return b > a ? a = b, 1 : 0; }

inline int read() {
	int x(0), sgn(1); char ch(getchar());
	for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
	for (; isdigit(ch); ch = getchar()) x = (x * 10) + (ch ^ 48);
	return x * sgn;
}

void File() {
#ifdef zjp_shadow
	freopen ("449.in", "r", stdin);
	freopen ("449.out", "w", stdout);
#endif
}

const int N = 55, K = 1010;

namespace Computation {
	const int Mod = 998244353;

	inline int fpm(int x, int power) {
		int res = 1;
		for (; power; power >>= 1, x = 1ll * x * x % Mod)
			if (power & 1) res = 1ll * res * x % Mod;
		return res;
	}

	inline void add(int &x, int y) { if ((x += y) >= Mod) x -= Mod; }
	inline void sub(int &x, int y) { if ((x -= y) < 0) x += Mod; }
#define plus Plus
	inline int plus(int x, int y) { return (x += y) >= Mod ? x - Mod : x; }
	inline int mul(int x, int y) { return 1ll * x * y % Mod; }
#define div Div
	inline int div(int x, int y) { return 1ll * x * fpm(y, Mod - 2) % Mod; }

	int fac[N * K], ifac[N * K];
	void Fac_Init(int maxn) {
		fac[0] = ifac[0] = 1;
		For (i, 1, maxn) fac[i] = mul(fac[i - 1], i);
		ifac[maxn] = fpm(fac[maxn], Mod - 2);
		Fordown (i, maxn - 1, 1) ifac[i] = mul(ifac[i + 1], i + 1);
	}
	inline int comb(int n, int m) {
		if (n < 0 || m < 0 || n < m) return 0;
		return mul(mul(fac[n], ifac[n - m]), ifac[m]);
	}
}

using namespace Computation;

int n, k, f[N * K][N], g[N * K][N];

int P[N], E[N], inv[N];

int main () {

	File();

	n = read(); k = read();

	Fac_Init(n * k);

	inv[1] = 1;
	For (i, 2, n)
		inv[i] = mul(inv[Mod % i], (Mod - Mod / i));
	For (i, 0, n)
		P[i] = inv[n - i], E[i] = mul(n, inv[n - i]);

	f[0][0] = 0; g[0][0] = 1;
	Rep (i, n * k) For (j, 0, i / k) if (g[i][j]) {
		int coefg = mul(g[i][j], P[j]),
			coeff = plus(mul(P[j], f[i][j]), mul(mul(P[j], E[j]), g[i][j])),
			coef = comb(i - j * k, k - 1);
		add(f[i + 1][j], coeff);
		add(g[i + 1][j], coefg);
		add(f[i + 1][j + 1], mul(coef, coeff));
		add(g[i + 1][j + 1], mul(coef, coefg));
	}
	printf ("%d\n", mul(f[n * k][n], fac[n]));

	return 0;

}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM