題意:從原點出發,走\(n\)次,每次往四個方向中隨機一個走,走每個方向有個概率,求所有方案走到過的點數的方差。
題解:orz kczno1
\(E*all=\sum (a_i-avg)^2*all=\sum a_i^2-2*\sum a_i*avg+avg^2*all*all\)。把\(avg=\frac{\sum a_i}{all}\)代入,可以發現要求的就是\(\sum a_i\)和\(\sum a_i^2\)。以下為了方便求概率。
設\(f_{i,x,y}\)為走了\(i\)步第一次走到\(x,y\)的概率,\(g_{i,x,y}\)為走了\(i\)步最后在\(x,y\)的概率,\(g\)可以暴力DP,\(f\)用\(g\)容斥一下就行了。\(\sum a_i\)顯然可以拆成對於每個點,求出經過它的概率並求和,就是\(\sum f\)。
考慮怎樣求\(\sum a_i^2\)。可以用類似管道取珠的方法,把\(x^2\)拆成\(\binom{x}{2}*2+x\),那么只要求\(\sum \binom{x}{2}\)。設\(h_{i,x,y}\)為走了\(i\)步第一次走到某個點\(a,b\)並且之前走到過\(a-x,b-y\)的概率和,那么\(h_{i,x,y}=\sum_{j<i,a,b}f_{j,a,b}f_{i-j,x,y}-\sum_{j<i}h_{j,-x,-y}f_{i-j,x,y}\),減掉的是第一次到達某個點\(a,b\)之前到達過\(a+x,b+y\)的概率。
再加一些卡常就uojrk2了。注意用16次一取膜優化的時候要保證每一次加的都為非負數,否則要改為(signed) long long並且改為8次一取膜。一開始因為這個WA飛了。
#include<bits/stdc++.h>
using namespace std;
const int mod = 998244353;
const int N = 110;
typedef long long ll;
int qpow(int a, int b) {
int ret = 1;
while(b) {
if(b & 1) {
ret = 1ll * ret * a % mod;
}
a = 1ll * a * a % mod, b >>= 1;
}
return ret;
}
int n, w[4], tt = 0, f[N][N * 2][N * 2], g[N][N * 2][N * 2], h[N][N * 2][N * 2], sf[N], dx[4] = {0, 0, -1, 1}, dy[4] = {1, -1, 0, 0}, s1 = 0, s2 = 0;
int main() {
cin >> n;
for(int i = 0; i < 4; i++) {
cin >> w[i], tt += w[i];
}
for(int i = 0; i < 4; i++) {
w[i] = 1ll * w[i] * qpow(tt, mod - 2) % mod;
}
f[0][N][N] = g[0][N][N] = sf[0] = 1;
for(int i = 1; i <= n; i++)
for(int x = -i; x <= i; x++)
for(int y = -i; y <= i; y++) {
if(abs(x) + abs(y) > i) {
continue;
}
ll sum = 0;
for(int k = 0; k < 4; k++) {
sum += 1ll * g[i - 1][x - dx[k] + N][y - dy[k] + N] * w[k];
}
g[i][x + N][y + N] = sum % mod;
sum = g[i][x + N][y + N];
for(int k = 0; k < i; k++) {
sum -= 1ll * f[k][x + N][y + N] * g[i - k][N][N];
if((k & 7) == 7) {
sum %= mod;
}
}
f[i][x + N][y + N] = (sum % mod + mod) % mod;
sf[i] = (sf[i] + f[i][x + N][y + N]) % mod;
sum = 0;
for(int k = 0; k < i; k++) {
sum += 1ll * (sf[k] - h[k][-x + N][-y + N]) * f[i - k][x + N][y + N];
if((k & 7) == 7) {
sum %= mod;
}
}
h[i][x + N][y + N] = (sum % mod + mod) % mod;
}
for(int i = 0; i <= n; i++) {
s1 = (s1 + sf[i]) % mod;
}
for(int i = 0; i <= n; i++)
for(int x = -i; x <= i; x++)
for(int y = -i; y <= i; y++) {
if(abs(x) + abs(y) > i) {
continue;
}
s2 = (s2 + h[i][x + N][y + N]) % mod;
}
tt = qpow(tt, n);
s1 = 1ll * s1 * tt % mod, s2 = 1ll * s2 * tt % mod;
cout << ((1ll * tt * (2ll * s2 + s1) - 1ll * s1 * s1) % mod + mod) % mod;
return 0;
}