A simple introduction to "Segment tree beats"
區間最值
以「 區間取 \(\min\),查詢區間和」為例,線段樹節點需儲存 \(mx,smx,cnt,sum\) 四個信息,即最大值,嚴格次大值,最大值個數,區間和。更新信息:
void update(int x, int l, int r, int L, int R, int val){
if(t[x].mx <= val) return;
if(l >= L && r <= R && t[x].smx < val){ addtag(x, val); return; }
int mid = (l+r)>>1;
pushdown(x);
if(mid >= L) update(x*2, l, mid, L, R, val);
if(mid < R) update(x*2+1, mid+1, r, L, R, val);
pushup(x);
}
在只有區間 \(\min,\max\) 操作時,時間復雜度為 \(O(n\log n)\),當有其他區間修改操作時,時間復雜度為 \(O(n\log^2n)\),但實際表現和 \(1\) 個 \(\log\) 差不多。
這樣的處理方式本質上就是對最大值或最小值專門進行維護,於是可以將信息分成兩類,最值和非最值,兩種分開維護,而區間 \(\min,\max\) 操作可以轉化為對最值的區間加減操作。
歷史最值
此處考慮的完整問題是:區間取 \(\min,\max\)、區間加、區間歷史最大值、區間歷史最大值之和。記 \(A_i\) 為原數組,\(B_i\) 為歷史最大值數組。
區間加,區間最大歷史最大值:
在加法懶標記 \(Add\) 之外再維護一個歷史最大加減標記 \(Pre\),表示從上一次標記下傳至今 \(Add\) 達到過的最大值,合並:\(Pre_{son}=\max(Pre_{son},Add_{son}+Pre_x),\;Add_{son}=Add_{son}+Add_x\),\(O(n\log n)\) 。
只有區間查詢歷史最小值/最大值:
將最值和非最值分開維護,則操作全部轉化為區間加減操作,可以沿用 \(Pre\) 懶標記進行維護,\(O(n\log^2n)\) 。
無區間 \(\min,\max\) 操作:
- 區間歷史最大值 & 歷史最大值之和:記 \(C_i=A_i-B_i\),區間加轉化為 \(C_i\rightarrow \min(C_i+x,0)\),\(O(n\log^2n)\)
- 區間歷史版本之和:令 \(t\) 為當前已結束的操作數,記 \(C_i=B_i+t\cdot A_i\),區間加轉化為 \(C_i\rightarrow C_i-x\cdot t\),\(O(n\log n)\)。
有區間 \(\min,\max\) 操作:
將最值和非最值分開維護 \(C_i\),則轉化為上面的「無區間 \(\min,\max\) 操作」問題,在分開維護部分會多出 \(1\) 個 \(\log\),從而「區間歷史最大值」和「歷史最大值之和」 \(O(n\log^3n)\),「區間歷史版本之和」 \(O(n\log^2n)\) 。
一些簡單的例題&實現
洛谷P6242 【模板】線段樹 3
即前面區間最值中的「無區間 \(\min,\max\) 操作」,線段樹節點維護 \(A_i\) 最大值、\(B_i\) 最大值、\(A_i\) 嚴格次小值、\(A_i\) 最大值個數、\(A_i\) 最大值和非最大值的加法標記,\(B_i\) 最大值和非最大值的歷史加法標記(即 \(Pre\)),然后就可以了。
#include<iostream>
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 505000
#define ll long long
#define Inf 0x7fffffff
using namespace std;
int n, m, a[N];
struct SegmentTreeBeats{
struct node{
int mx_a, smx, cnt, mx_b;
ll sum, add_a1, add_a2, add_b1, add_b2;
// add_a1, add_a2: lazy add tag add_b1, add_b2: historical tag
} t[N<<2];
void pushup(int x){
t[x].mx_a = max(t[x*2].mx_a, t[x*2+1].mx_a);
t[x].mx_b = max(t[x*2].mx_b, t[x*2+1].mx_b);
t[x].sum = t[x*2].sum + t[x*2+1].sum;
if(t[x*2].mx_a == t[x*2+1].mx_a) t[x].smx = max(t[x*2].smx, t[x*2+1].smx);
else if(t[x*2].mx_a > t[x*2+1].mx_a) t[x].smx = max(t[x*2].smx, t[x*2+1].mx_a);
else t[x].smx = max(t[x*2].mx_a, t[x*2+1].smx);
t[x].cnt = (t[x*2].mx_a >= t[x*2+1].mx_a) * t[x*2].cnt + (t[x*2+1].mx_a >= t[x*2].mx_a) * t[x*2+1].cnt;
}
// a: to max, b: to historical max, c: to non max, d: to historical non max
void addtag(int x, int l, int r, ll a, ll b, ll c, ll d){
t[x].sum += a*t[x].cnt + c*(r-l+1-t[x].cnt);
t[x].mx_b = max((ll)t[x].mx_b, t[x].mx_a + b);
t[x].add_b1 = max(t[x].add_b1, t[x].add_a1 + b);
t[x].add_b2 = max(t[x].add_b2, t[x].add_a2 + d);
t[x].mx_a += a, t[x].add_a1 += a, t[x].add_a2 += c;
if(t[x].smx > -Inf) t[x].smx += c;
}
void pushdown(int x, int l, int r){
int mid = (l+r)>>1, mx = max(t[x*2].mx_a, t[x*2+1].mx_a);
ll add_a1 = t[x].add_a1, add_a2 = t[x].add_a2, add_b1 = t[x].add_b1, add_b2 = t[x].add_b2;
if(t[x*2].mx_a == mx) addtag(x*2, l, mid, add_a1, add_b1, add_a2, add_b2);
else addtag(x*2, l, mid, add_a2, add_b2, add_a2, add_b2);
if(t[x*2+1].mx_a == mx) addtag(x*2+1, mid+1, r, add_a1, add_b1, add_a2, add_b2);
else addtag(x*2+1, mid+1, r, add_a2, add_b2, add_a2, add_b2);
t[x].add_a1 = t[x].add_b1 = t[x].add_a2 = t[x].add_b2 = 0;
}
void build(int x, int l, int r){
if(l == r){
t[x].sum = t[x].mx_a = t[x].mx_b = a[l];
t[x].smx = -Inf, t[x].cnt = 1;
return;
}
int mid = (l+r)>>1;
build(x*2, l, mid), build(x*2+1, mid+1, r);
pushup(x);
}
void update_add(int x, int l, int r, int L, int R, int k){
if(l >= L && r <= R){ addtag(x, l, r, k, k, k, k); return; }
int mid = (l+r)>>1;
pushdown(x, l, r);
if(mid >= L) update_add(x*2, l, mid, L, R, k);
if(mid < R) update_add(x*2+1, mid+1, r, L, R, k);
pushup(x);
}
void update_min(int x, int l, int r, int L, int R, int val){
if(t[x].mx_a <= val) return;
if(l >= L && r <= R && t[x].smx < val){ addtag(x, l, r, val-t[x].mx_a, val-t[x].mx_a, 0, 0); return; }
int mid = (l+r)>>1;
pushdown(x, l, r);
if(mid >= L) update_min(x*2, l, mid, L, R, val);
if(mid < R) update_min(x*2+1, mid+1, r, L, R, val);
pushup(x);
}
ll query(int x, int l, int r, int L, int R, int id){
if(l >= L && r <= R) return id == 1 ? t[x].sum : (id == 2 ? t[x].mx_a : t[x].mx_b);
int mid = (l+r)>>1;
pushdown(x, l, r);
ll a = (mid >= L) ? query(x*2, l, mid, L, R, id) : (id == 1 ? 0 : -Inf);
ll b = (mid < R) ? query(x*2+1, mid+1, r, L, R, id) : (id == 1 ? 0 : -Inf);
return (id == 1 ? a+b : max(a, b));
}
} T;
int main(){
ios::sync_with_stdio(false);
cin>>n>>m;
rep(i,1,n) cin>>a[i];
T.build(1, 1, n);
int type, l, r, k;
while(m--){
cin>>type>>l>>r;
switch(type){
case 1 : cin>>k, T.update_add(1, 1, n, l, r, k); break;
case 2 : cin>>k, T.update_min(1, 1, n, l, r, k); break;
default : cout<< T.query(1, 1, n, l, r, type-2) <<endl;
}
}
return 0;
}
Codeforces 1290E Cartesian Tree
給定一個長為 \(n\) 的排列的 \(a_i\),對於每個 \(k\in[1,n]\),以前 \(k\) 小的值(保持排列中的順序)建笛卡爾樹,求出所有子樹大小之和。\(1\leq n\leq 150000\)
記 \(l_i,r_i\) 分別為 \(a_i\) 從左右得到的最后一個比其小的數的位置(保留前 \(k\) 小的值組成序列的位置),則節點 \(i\) 的子樹大小即 \(r_i-l_i+1\),於是答案即為 \(\sum r_i-\sum l_i+k\)。分開維護 \(l_i\) 和 \(r_i\),以 \(r_i\) 為例,容易發現當 \(k\) 增加 \(1\),\(k+1\) 插入到序列中時,\(pos_{k+1}\) 左側的 \(r_i\rightarrow \min(r_i,pos_{k+1})\),而右側的 \(r_i\) 因為左邊添加了元素,\(r_i\rightarrow r_i+1\),\(l_i\) 也是類似的處理。所以我們使用 Segment Tree Beats 維護區間 \(\min\)、區間加、全局和即可,\(O(n\log^2n)\) 。
#include<iostream>
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 160000
#define Inf 0x3f3f3f3f
#define ll long long
#define lowbit(x) (x&-x)
using namespace std;
int n;
int a[N], pos[N];
struct Segment_Tree_Beats{
struct node{
int mx, smx, cnt, num, tag = -1, lazy;
ll sum;
} t[N<<2];
void pushup(int x){
t[x].mx = max(t[x*2].mx, t[x*2+1].mx);
t[x].num = t[x*2].num + t[x*2+1].num;
t[x].sum = t[x*2].sum + t[x*2+1].sum;
if(t[x*2].mx == t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].smx);
else if(t[x*2].mx > t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].mx);
else t[x].smx = max(t[x*2].mx, t[x*2+1].smx);
t[x].cnt = (t[x*2].mx >= t[x*2+1].mx) * t[x*2].cnt + (t[x*2+1].mx >= t[x*2].mx) * t[x*2+1].cnt;
}
void pushadd(int x, int k){
t[x].sum += (ll)k * t[x].num;
if(t[x].num){
t[x].mx += k, t[x].lazy += k;
if(t[x].smx > 0) t[x].smx += k;
if(~t[x].tag) t[x].tag += k;
}
}
void pushmin(int x, int k){
if(t[x].mx <= k) return;
t[x].sum -= (ll)t[x].cnt * (t[x].mx - k);
t[x].mx = k, t[x].tag = k;
}
void pushdown(int x){
if(t[x].lazy) pushadd(x*2, t[x].lazy), pushadd(x*2+1, t[x].lazy);
if(~t[x].tag) pushmin(x*2, t[x].tag), pushmin(x*2+1, t[x].tag);
t[x].tag = -1, t[x].lazy = 0;
}
void insert(int x, int l, int r, int pos, int val){
if(l == r){ t[x].sum = t[x].mx = val, t[x].cnt = t[x].num = 1; return; }
int mid = (l+r)>>1;
pushdown(x);
if(mid >= pos) insert(x*2, l, mid, pos, val);
else insert(x*2+1, mid+1, r, pos, val);
pushup(x);
}
void update(int x, int l, int r, int L, int R, int k, int id){
if(L > R) return;
if(id && t[x].mx <= k) return;
if(l >= L && r <= R){
if(id && t[x].smx < k){ pushmin(x, k); return; }
else if(!id){ pushadd(x, k); return; }
}
int mid = (l+r)>>1;
pushdown(x);
if(mid >= L) update(x*2, l, mid, L, R, k, id);
if(mid < R) update(x*2+1, mid+1, r, L, R, k, id);
pushup(x);
}
} LB, RB;
struct Fenwick{
int t[N];
void update(int pos, int k){
while(pos <= n) t[pos] += k, pos += lowbit(pos);
}
int get(int pos){
int ret = 0;
while(pos) ret += t[pos], pos -= lowbit(pos);
return ret;
}
} T;
int main(){
ios::sync_with_stdio(false);
cin>>n;
rep(i,1,n) cin>>a[i], pos[a[i]] = i;
rep(i,1,n){
RB.update(1, 1, n, pos[i]+1, n, 1, 0);
RB.update(1, 1, n, 1, pos[i]-1, T.get(pos[i]), 1);
LB.update(1, 1, n, 1, n-pos[i]+1, -1, 0);
LB.update(1, 1, n, 1, n-pos[i]+1, n-T.get(pos[i])-1, 1);
RB.insert(1, 1, n, pos[i], i), LB.insert(1, 1, n, n-pos[i]+1, n);
cout<< RB.t[1].sum - ((ll)i*(n+1) - LB.t[1].sum) + i <<endl;
T.update(pos[i], 1);
}
return 0;
}
Codeforces 1572F Stations
有 \(n\) 個城市,每個城市有兩個屬性 \(h_i,w_i\),第 \(i\) 個城市的廣播可以覆蓋到所有 \([i,w_i]\) 中滿足 \(\max_{i<k\leq j}\{h_k\}<h_i\) 的城市 \(j\)。
初始時所有 \(h_i=0,w_i=i\),每次操作可以使得城市 \(c_i\) 的 \(h\) 成為全局嚴格最大值,並修改 \(w_{c_i}\),或者詢問 \([l,r]\) 中,對於每個城市來說廣播可以覆蓋到它的城市個數之和。
\(1\leq n\leq 2\times 10^5\)
顯然每個城市覆蓋的區域是一個從 \(i\) 開始的區間,而維護區間右端點是一個「 單點修改,區間取 \(\min\)」問題,另外用一棵線段樹維護區間覆蓋,注意到 Segment Tree Beats 里值的修改是直接進行的,所以修改時可以順帶在另一棵線段樹上進行區間修改,即打 \(tag\) 時維護即可。詢問時直接在線段樹查詢區間和。時間復雜度 \(O(n\log^2n)\) 。
#include<iostream>
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 200021
#define ll long long
using namespace std;
int n, q;
struct SegmentTree{
struct node{
ll sum, lazy;
} t[N<<2];
void pushdown(int x, int l, int r){
int mid = (l+r)>>1;
t[x*2].sum += (mid-l+1) * t[x].lazy, t[x*2].lazy += t[x].lazy;
t[x*2+1].sum += (r-mid) * t[x].lazy, t[x*2+1].lazy += t[x].lazy;
t[x].lazy = 0;
}
void update(int x, int l, int r, int L, int R, ll k){
if(L > R) return;
if(l >= L && r <= R){ t[x].sum += (r-l+1) * k, t[x].lazy += k; return; }
int mid = (l+r)>>1;
pushdown(x, l, r);
if(mid >= L) update(x*2, l, mid, L, R, k);
if(mid < R) update(x*2+1, mid+1, r, L, R, k);
t[x].sum = t[x*2].sum + t[x*2+1].sum;
}
ll get(int x, int l, int r, int L, int R){
if(l >= L && r <= R) return t[x].sum;
int mid = (l+r)>>1; ll ret = 0;
pushdown(x, l, r);
if(mid >= L) ret += get(x*2, l, mid, L, R);
if(mid < R) ret += get(x*2+1, mid+1, r, L, R);
return ret;
}
} BIT;
struct Segment_Tree_Beats{
struct node{
int mx, smx, cnt, tag = -1;
} t[N<<2];
void pushup(int x){
t[x].mx = max(t[x*2].mx, t[x*2+1].mx);
if(t[x*2].mx == t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].smx);
else if(t[x*2].mx > t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].mx);
else t[x].smx = max(t[x*2].mx, t[x*2+1].smx);
t[x].cnt = (t[x*2].mx >= t[x*2+1].mx) * t[x*2].cnt + (t[x*2+1].mx >= t[x*2].mx) * t[x*2+1].cnt;
}
void addtag(int x, int k, bool frt){
if(t[x].mx <= k) return;
if(frt) BIT.update(1, 1, n, k+1, t[x].mx, -t[x].cnt);
t[x].mx = t[x].tag = k;
}
void pushdown(int x){
if(~t[x].tag)
addtag(x*2, t[x].tag, 0), addtag(x*2+1, t[x].tag, 0);
t[x].tag = -1;
}
void insert(int x, int l, int r, int pos, int val){
if(l == r){
BIT.update(1, 1, n, l, t[x].mx, -t[x].cnt), BIT.update(1, 1, n, l, val, 1);
t[x].mx = val, t[x].cnt = 1; return;
}
int mid = (l+r)>>1;
pushdown(x);
if(mid >= pos) insert(x*2, l, mid, pos, val);
else insert(x*2+1, mid+1, r, pos, val);
pushup(x);
}
void update(int x, int l, int r, int L, int R, int k){
if(t[x].mx <= k || L > R) return;
if(l >= L && r <= R && t[x].smx < k){ addtag(x, k, 1); return; }
int mid = (l+r)>>1;
pushdown(x);
if(mid >= L) update(x*2, l, mid, L, R, k);
if(mid < R) update(x*2+1, mid+1, r, L, R, k);
pushup(x);
}
} T;
int main(){
ios::sync_with_stdio(false);
cin>>n>>q;
rep(i,1,n) T.insert(1, 1, n, i, i);
int type, c, g, l, r;
while(q--){
cin>>type;
if(type == 1){
cin>>c>>g;
T.insert(1, 1, n, c, g), T.update(1, 1, n, 1, c-1, c-1);
} else cin>>l>>r, cout<< BIT.get(1, 1, n, l, r) <<endl;
}
return 0;
}