Loj #2541. 「PKUWC2018」獵人殺
好巧妙的題!
游戲過程中,概率的分母一直在變化,所以就非常的不可做。
所以我們將問題轉化一下:我們可以重復選擇相同的獵人,只不過在一個獵人被選擇了過后我們就給他打上標記,再次選擇他的時候就無效。這樣與原問題是等價的。
證明:
設\(sum=\sum_iw_i,kill=\sum_{i被殺死了}w_i\)。
攻擊到未被殺死的獵人\(i\)的概率為\(P\)。
則根據題意\(P=\frac{w_i}{sum-kill}\)。
問題轉化后:
          \[\\P=\frac{kill}{sum}P+\frac{w_i}{sum}\\ \Rightarrow P=\frac{w_i}{sum-kill}。 \] 
        
 
         
        然后我們考慮容斥:枚舉集合\(T\)中的獵人一定在\(1\)之后被殺死,其他獵人隨意。
我們設\(S=\sum_{i\in T}w_i\)
則:
          \[\displaystyle \begin{align} ans&=(-1)^{|T|}\sum_{i=0}^{\infty}(1-\frac{S+w_1}{sum})^i\frac{w_1}{sum} \\&=(-1)^{|T|}\frac{1}{1-(1-\frac{S+w_1}{sum})}\frac{w_1}{sum} \\&=(-1)^{|T|}\frac{w_1}{w_1+S} \end{align} \] 
        
 
         
         
         
        然后我們就可以用背包背出所有\(\sum w_i\)恰好為\(S\)的帶上容斥系數的方案數。
但復雜度有點高,於是我們考慮用生成函數來優化。這道題的生成函數還是比較簡單,就是\(\Pi (1-x^{w_i})\)。用分治\(NTT\)實現。
代碼:
#include<bits/stdc++.h>
#define ll long long
#define mod 998244353
#define N 100005
using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
ll ksm(ll t,ll x)  {
	ll ans=1;
	for(;x;x>>=1,t=t*t%mod)
		if(x&1) ans=ans*t%mod;
	return ans;
}
int rev[N<<2];
void NTT(ll *a,int d,int flag) {
	static ll G=3;
	int n=1<<d;
	for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<d-1);
	for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int s=1;s<=d;s++) {
		int len=1<<s,mid=len>>1;
		ll w=flag==1?ksm(G,(mod-1)/len):ksm(G,mod-1-(mod-1)/len);
		for(int i=0;i<n;i+=len) {
			ll t=1;
			for(int j=0;j<mid;j++,t=t*w%mod) {
				ll u=a[i+j],v=a[i+j+mid]*t%mod;
				a[i+j]=(u+v)%mod;
				a[i+j+mid]=(u-v+mod)%mod;
			}
		}
	}
	if(flag==-1) {
		ll inv=ksm(n,mod-2);
		for(int i=0;i<n;i++) a[i]=a[i]*inv%mod;
	}
}
ll f[N];
int n,w[N];
int sum[N];
int binary(int lx,int rx) {
	int mid,l=lx,r=rx;
	while(l<r) {
		mid=l+r+1>>1;
		if(sum[mid]-sum[lx-1]<=sum[rx]-sum[mid]) l=mid;
		else r=mid-1;
	}
	return l;
}
void solve(int l,int r,ll *f) {
	if(l==r) {
		f[0]=1;
		f[w[l]]=mod-1;
		return ;
	}
	int mid=binary(l,r);
	const int d=ceil(log2(sum[r]-sum[l-1]))+1;
	ll *a=new ll[(1<<d)+5],*b=new ll[(1<<d)+5];
	for(int i=0;i<(1<<d);i++) a[i]=b[i]=0;
	solve(l,mid,a),solve(mid+1,r,b);
	NTT(a,d,1),NTT(b,d,1);
	for(int i=0;i<(1<<d);i++) a[i]=a[i]*b[i]%mod;
	NTT(a,d,-1);
	for(int i=0;i<(1<<d);i++) f[i]=a[i];
	delete a;
	delete b;
}
ll ans;
int main() {
	n=Get();
	for(int i=1;i<=n;i++) w[i]=Get();
	sort(w+2,w+1+n);
	for(int i=2;i<=n;i++) sum[i]=sum[i-1]+w[i];
	solve(2,n,f);
	int tot=sum[n]-sum[1];
	for(int i=0;i<=tot;i++) {
		(ans+=f[i]*w[1]%mod*ksm(w[1]+i,mod-2)%mod)%=mod;
	}
	cout<<ans;
	return 0;
}
 
         
         
       