題目鏈接
題目大意
給了一個序列 \(A_1,A_2,...,A_N\),求出滿足 \(1\leq X_i\leq A_i\) 且相鄰元素不同的序列 \(X\) 的數量,答案模 \(998244353\) 。
\(2\leq N\leq 5\times 10^5\),\(1\leq A_i\leq 10^9\)
思路
考慮容斥,設 \(f_i\) 為前 \(i\) 位的答案,有轉移式:
\[f_i=\sum_{j=0}^{i-1}f_j\cdot \min_{j<k\leq i}\{A_k\}\cdot (-1)^{i-j-1} \]
注意到每次多一個 \(A_i\) 是會改變一個區間的轉移系數的,把 \((-1)^i\) 扔出來后,用線段樹維護所有地方的權值和(就是求和式的值),用棧記錄 \(A_i\) 影響了哪些位置的權值,每次把當前的 \(A_i\) 插入進去,更新線段樹即可。
時間復雜度 \(O(Nlog_2N)\)
Code
#include<iostream>
#include<stack>
#include<cstdio>
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define per(i,b,a) for(int i=(b);i>=(a);i--)
#define N 500500
#define M 2000021
#define mod 998244353
#define ll long long
using namespace std;
int a[N],n;
struct SegmentTree{
ll w,val,lazy;
}t[M];
stack<int> mn;
ll dp[N];
void pushdown(int x){
if(t[x].lazy){
t[x*2].w=(t[x*2].val*t[x].lazy)%mod;
t[x*2+1].w=(t[x*2+1].val*t[x].lazy)%mod;
t[x*2].lazy=t[x].lazy,t[x*2+1].lazy=t[x].lazy;
}
}
void update(int x,int l,int r,int a,int b,ll k){
if(l>=a&&r<=b){
t[x].lazy=k,t[x].w=(t[x].val*k)%mod;
return;
}
int mid=(l+r)>>1;
if(mid>=a)update(x*2,l,mid,a,b,k);
if(mid<b)update(x*2+1,mid+1,r,a,b,k);
t[x].w=(t[x*2].w+t[x*2+1].w)%mod;
t[x].val=(t[x*2].val+t[x*2+1].val)%mod;
}
void insert(int x,int l,int r,int loc,ll k){
if(l==r){
t[x].val=k,t[x].w=(t[x].val*t[x].lazy)%mod;
return;
}
pushdown(x);
int mid=(l+r)>>1;
if(mid>=loc)insert(x*2,l,mid,loc,k);
else insert(x*2+1,mid+1,r,loc,k);
t[x].w=(t[x*2].w+t[x*2+1].w)%mod;
t[x].val=(t[x*2].val+t[x*2+1].val)%mod;
}
int main(){
ios::sync_with_stdio(false);
cin>>n;
rep(i,1,n)cin>>a[i];
mn.push(0);
dp[0]=1;
insert(1,0,n,0,mod-dp[0]);
rep(i,1,n){
while(a[mn.top()]>a[i])mn.pop();
update(1,0,n,mn.top(),i-1,a[i]);
mn.push(i);
dp[i]=(i&1)?mod-t[1].w:t[1].w;
insert(1,0,n,i,(i&1)?dp[i]:mod-dp[i]);
}
cout<<dp[n]<<endl;
return 0;
}
思路2
然而這道題是可以做到線性的,題目這個一個點可以控制一段區間的性質,可以考慮用笛卡爾樹維護,和前面一樣建一個棧,不過棧中要儲存 \(A_i\) 下標和其控制的權值和兩樣東西,然后直接根據奇偶性計算就可以了,時間復雜度 \(O(N)\) 。
講起來有點玄,看代碼吧。
Code
#include<iostream>
#include<stack>
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define per(i,b,a) for(int i=(b);i>=(a);i--)
#define N 500500
#define ll long long
#define mod 998244353
#define fr first
#define sc second
using namespace std;
int a[N],n;
int main(){
ios::sync_with_stdio(false);
cin>>n;
rep(i,1,n)cin>>a[i];
ll val=0,sum=0;
stack<pair<ll,int> > s;
rep(i,1,n){
val=sum+(i==1);
(sum+=mod-val*a[i]%mod)%=mod;
while(!s.empty()&&s.top().second>a[i]){
pair<ll,int> cur=s.top();
s.pop();
(sum+=cur.fr*(cur.sc-a[i]))%=mod;
(val+=cur.fr)%=mod;
}
s.push({val,a[i]});
}
cout<<(n%2?mod-sum:sum)<<endl;
return 0;
}