題面
給 \(n\) 個數的序列 \(a_i\),求有多少種 \(n\) 個數的排列 \(p_i\),使得
\[\frac{a_{p_i}}{\max_{j=1}^{i-1} a_{p_j}}\notin \left(\frac 12, 2\right) \]答案膜 \(998244353\)。
數據范圍:\(2\le n\le 5000\)。
題解
您會很高興地發現當前 \(a_{p_i}\) 大於前面最大值的總次數不會超過 \(\log\) 次,可惜這沒用。
但是考慮這個前綴最大值的序列,卻可以發現這東西改變序列便是不同的。
將 \(a_i\) 排序 ,考慮用 \(a_i\) 填滿一個長度為 \(n\) 的空位序列。
設 \(f_i\) 表示當前最大值是 \(a_i\) 的填充方案數,\(lim_i\) 表示最大的 \(j\) 滿足 \(2a_j\le a_i\)。
先把這個最大值填當前的第一個空位。
由於最大值只會增大,對於所有 \(2a_j\le a_i\) 必定滿足 \(2a_j\) 小於未來的最大值,所以可以先把 \(2a_j\le a_i\) 的 未使用的 都填在最大值后面的空位上(不需要連續),順序重要。
為什么不需要多設一維表示當前放的個數呢?因為 \(f_i\) 放的個數必定是 \(1+lim_i\),這個 \(1\) 指的是自己。
所以轉移方程是:
\[f_i=\sum_{j=0}^{lim_i} f_j\cdot A(n-2-lim_j,lim_i-lim_j-1) \]
這里 \(n-2-lim_j\) 是總空位數減去 \(f_j\) 用的空位數減去當前最大值 \(1\) 個,\(lim_i-lim_j-1\) 是 \(2a_t\le a_i\) 的未使用的 \(a_t\) 個數。
時間復雜度 \(\Theta(n^2)\),空間復雜度 \(\Theta(n)\)。
代碼
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef double db;
#define mp(a,b) make_pair((a),(b))
#define x first
#define y second
#define bg begin()
#define ed end()
#define sz(a) int((a).size())
#define pb(a) push_back(a)
#define R(i,a,b) for(int i=(a),i##E=(b);i<i##E;i++)
#define L(i,a,b) for(int i=(b)-1,i##E=(a)-1;i>i##E;i--)
const int iinf=0x3f3f3f3f;
const ll linf=0x3f3f3f3f3f3f3f3f;
//Data
const int N=5000;
int n,a[N],f[N+1],lim[N+1];
//Math
const int mod=998244353;
void fmod(int&x){x+=mod&x>>31;}
int Pow(int a,int x){
int res=1; for(;x;x>>=1,a=1ll*a*a%mod)
if(x&1) res=1ll*res*a%mod; return res;
}
int fac[N+1],ifac[N+1];
void math_init(){
fac[0]=1; R(i,1,n+1) fac[i]=1ll*fac[i-1]*i%mod;
ifac[n]=Pow(fac[n],mod-2);
L(i,0,n) ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
}
int A(int u,int v){
if(u<0||v<0||v>u) return 0;
return 1ll*fac[u]*ifac[u-v]%mod;
}
//Main
int main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>n,math_init();
R(i,0,n) cin>>a[i]; sort(a,a+n);
R(i,0,n){
int l=-1,r=i+1;
while(r-l>1){
int mid=(l+r)>>1;
if(a[mid]*2>a[i]) r=mid;
else l=mid;
}
lim[i+1]=r;
}
f[0]=1,lim[0]=-1;
R(i,1,n+1)R(j,0,lim[i]+1)
fmod(f[i]+=1ll*f[j]*A(n-2-lim[j],lim[i]-lim[j]-1)%mod-mod);
// R(i,0,n+1) cout<<f[i]<<" ";cout<<'\n';
if(lim[n]==n-1) cout<<f[n]<<'\n';
else cout<<0<<'\n';
return 0;
}
祝大家學習愉快!