Segment Tree Beats 學習筆記


2016集訓隊論文 吉如一《區間最值操作與歷史最值問題》

A simple introduction to "Segment tree beats"

區間最值

以「 區間取 \(\min\),查詢區間和」為例,線段樹節點需儲存 \(mx,smx,cnt,sum\) 四個信息,即最大值,嚴格次大值,最大值個數,區間和。更新信息:

void update(int x, int l, int r, int L, int R, int val){
    if(t[x].mx <= val) return;
    if(l >= L && r <= R && t[x].smx < val){ addtag(x, val); return; }
    int mid = (l+r)>>1;
    pushdown(x);
    if(mid >= L) update(x*2, l, mid, L, R, val);
    if(mid < R) update(x*2+1, mid+1, r, L, R, val);
    pushup(x);
}

在只有區間 \(\min,\max\) 操作時,時間復雜度為 \(O(n\log n)\),當有其他區間修改操作時,時間復雜度為 \(O(n\log^2n)\),但實際表現和 \(1\)\(\log\) 差不多。

這樣的處理方式本質上就是對最大值或最小值專門進行維護,於是可以將信息分成兩類,最值和非最值,兩種分開維護,而區間 \(\min,\max\) 操作可以轉化為對最值的區間加減操作。

歷史最值

此處考慮的完整問題是:區間取 \(\min,\max\)、區間加、區間歷史最大值、區間歷史最大值之和。記 \(A_i\) 為原數組,\(B_i\) 為歷史最大值數組。

區間加,區間最大歷史最大值:

  在加法懶標記 \(Add\) 之外再維護一個歷史最大加減標記 \(Pre\),表示從上一次標記下傳至今 \(Add\) 達到過的最大值,合並:\(Pre_{son}=\max(Pre_{son},Add_{son}+Pre_x),\;Add_{son}=Add_{son}+Add_x\)\(O(n\log n)\)

只有區間查詢歷史最小值/最大值:

  將最值和非最值分開維護,則操作全部轉化為區間加減操作,可以沿用 \(Pre\) 懶標記進行維護,\(O(n\log^2n)\)

無區間 \(\min,\max\) 操作:

  • 區間歷史最大值 & 歷史最大值之和:記 \(C_i=A_i-B_i\),區間加轉化為 \(C_i\rightarrow \min(C_i+x,0)\)\(O(n\log^2n)\)
  • 區間歷史版本之和:令 \(t\) 為當前已結束的操作數,記 \(C_i=B_i+t\cdot A_i\),區間加轉化為 \(C_i\rightarrow C_i-x\cdot t\)\(O(n\log n)\)

有區間 \(\min,\max\) 操作:

  將最值和非最值分開維護 \(C_i\),則轉化為上面的「無區間 \(\min,\max\) 操作」問題,在分開維護部分會多出 \(1\)\(\log\),從而「區間歷史最大值」和「歷史最大值之和」 \(O(n\log^3n)\),「區間歷史版本之和」 \(O(n\log^2n)\)


 

一些簡單的例題&實現

洛谷P6242 【模板】線段樹 3

即前面區間最值中的「無區間 \(\min,\max\) 操作」,線段樹節點維護 \(A_i\) 最大值、\(B_i\) 最大值、\(A_i\) 嚴格次小值、\(A_i\) 最大值個數、\(A_i\) 最大值和非最大值的加法標記,\(B_i\) 最大值和非最大值的歷史加法標記(即 \(Pre\)),然后就可以了。

#include<iostream>
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 505000
#define ll long long
#define Inf 0x7fffffff
using namespace std;

int n, m, a[N];

struct SegmentTreeBeats{
    struct node{
        int mx_a, smx, cnt, mx_b;
        ll sum, add_a1, add_a2, add_b1, add_b2;
        // add_a1, add_a2: lazy add tag   add_b1, add_b2: historical tag
    } t[N<<2];

    void pushup(int x){
        t[x].mx_a = max(t[x*2].mx_a, t[x*2+1].mx_a);
        t[x].mx_b = max(t[x*2].mx_b, t[x*2+1].mx_b);
        t[x].sum = t[x*2].sum + t[x*2+1].sum;
        if(t[x*2].mx_a == t[x*2+1].mx_a) t[x].smx = max(t[x*2].smx, t[x*2+1].smx);
        else if(t[x*2].mx_a > t[x*2+1].mx_a) t[x].smx = max(t[x*2].smx, t[x*2+1].mx_a);
        else t[x].smx = max(t[x*2].mx_a, t[x*2+1].smx);
        t[x].cnt = (t[x*2].mx_a >= t[x*2+1].mx_a) * t[x*2].cnt + (t[x*2+1].mx_a >= t[x*2].mx_a) * t[x*2+1].cnt;
    }

    // a: to max, b: to historical max, c: to non max, d: to historical non max
    void addtag(int x, int l, int r, ll a, ll b, ll c, ll d){
        t[x].sum += a*t[x].cnt + c*(r-l+1-t[x].cnt);
        t[x].mx_b = max((ll)t[x].mx_b, t[x].mx_a + b);
        t[x].add_b1 = max(t[x].add_b1, t[x].add_a1 + b);
        t[x].add_b2 = max(t[x].add_b2, t[x].add_a2 + d);
        t[x].mx_a += a, t[x].add_a1 += a, t[x].add_a2 += c;
        if(t[x].smx > -Inf) t[x].smx += c;
    }

    void pushdown(int x, int l, int r){
        int mid = (l+r)>>1, mx = max(t[x*2].mx_a, t[x*2+1].mx_a);
        ll add_a1 = t[x].add_a1, add_a2 = t[x].add_a2, add_b1 = t[x].add_b1, add_b2 = t[x].add_b2;
        if(t[x*2].mx_a == mx) addtag(x*2, l, mid, add_a1, add_b1, add_a2, add_b2);
        else addtag(x*2, l, mid, add_a2, add_b2, add_a2, add_b2);
        if(t[x*2+1].mx_a == mx) addtag(x*2+1, mid+1, r, add_a1, add_b1, add_a2, add_b2);
        else addtag(x*2+1, mid+1, r, add_a2, add_b2, add_a2, add_b2);
        t[x].add_a1 = t[x].add_b1 = t[x].add_a2 = t[x].add_b2 = 0;
    }

    void build(int x, int l, int r){
        if(l == r){
            t[x].sum = t[x].mx_a = t[x].mx_b = a[l];
            t[x].smx = -Inf, t[x].cnt = 1;
            return;
        }
        int mid = (l+r)>>1;
        build(x*2, l, mid), build(x*2+1, mid+1, r);
        pushup(x);
    }

    void update_add(int x, int l, int r, int L, int R, int k){
        if(l >= L && r <= R){ addtag(x, l, r, k, k, k, k); return; }
        int mid = (l+r)>>1;
        pushdown(x, l, r);
        if(mid >= L) update_add(x*2, l, mid, L, R, k);
        if(mid < R) update_add(x*2+1, mid+1, r, L, R, k);
        pushup(x);
    }

    void update_min(int x, int l, int r, int L, int R, int val){
        if(t[x].mx_a <= val) return;
        if(l >= L && r <= R && t[x].smx < val){ addtag(x, l, r, val-t[x].mx_a, val-t[x].mx_a, 0, 0); return; }
        int mid = (l+r)>>1;
        pushdown(x, l, r);
        if(mid >= L) update_min(x*2, l, mid, L, R, val);
        if(mid < R) update_min(x*2+1, mid+1, r, L, R, val);
        pushup(x);
    }

    ll query(int x, int l, int r, int L, int R, int id){
        if(l >= L && r <= R) return id == 1 ? t[x].sum : (id == 2 ? t[x].mx_a : t[x].mx_b); 
        int mid = (l+r)>>1;
        pushdown(x, l, r);
        ll a = (mid >= L) ? query(x*2, l, mid, L, R, id) : (id == 1 ? 0 : -Inf);
        ll b = (mid < R) ? query(x*2+1, mid+1, r, L, R, id) : (id == 1 ? 0 : -Inf);
        return (id == 1 ? a+b : max(a, b));
    }
} T;

int main(){
    ios::sync_with_stdio(false);
    cin>>n>>m;
    rep(i,1,n) cin>>a[i];
    T.build(1, 1, n);
    int type, l, r, k;
    while(m--){
        cin>>type>>l>>r;
        switch(type){
            case 1 : cin>>k, T.update_add(1, 1, n, l, r, k); break;
            case 2 : cin>>k, T.update_min(1, 1, n, l, r, k); break;
            default : cout<< T.query(1, 1, n, l, r, type-2) <<endl;
        }
    }
    return 0;
}

 

Codeforces 1290E Cartesian Tree

給定一個長為 \(n\) 的排列的 \(a_i\),對於每個 \(k\in[1,n]\),以前 \(k\) 小的值(保持排列中的順序)建笛卡爾樹,求出所有子樹大小之和。\(1\leq n\leq 150000\)

\(l_i,r_i\) 分別為 \(a_i\) 從左右得到的最后一個比其小的數的位置(保留前 \(k\) 小的值組成序列的位置),則節點 \(i\) 的子樹大小即 \(r_i-l_i+1\),於是答案即為 \(\sum r_i-\sum l_i+k\)。分開維護 \(l_i\)\(r_i\),以 \(r_i\) 為例,容易發現當 \(k\) 增加 \(1\)\(k+1\) 插入到序列中時,\(pos_{k+1}\) 左側的 \(r_i\rightarrow \min(r_i,pos_{k+1})\),而右側的 \(r_i\) 因為左邊添加了元素,\(r_i\rightarrow r_i+1\)\(l_i\) 也是類似的處理。所以我們使用 Segment Tree Beats 維護區間 \(\min\)、區間加、全局和即可,\(O(n\log^2n)\)

#include<iostream>
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 160000
#define Inf 0x3f3f3f3f
#define ll long long
#define lowbit(x) (x&-x)
using namespace std; 

int n;
int a[N], pos[N];

struct Segment_Tree_Beats{
    struct node{
	int mx, smx, cnt, num, tag = -1, lazy;
	ll sum;
    } t[N<<2];

    void pushup(int x){
	t[x].mx = max(t[x*2].mx, t[x*2+1].mx);
	t[x].num = t[x*2].num + t[x*2+1].num;
	t[x].sum = t[x*2].sum + t[x*2+1].sum;
	if(t[x*2].mx == t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].smx);
	else if(t[x*2].mx > t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].mx);
	else t[x].smx = max(t[x*2].mx, t[x*2+1].smx);
	t[x].cnt = (t[x*2].mx >= t[x*2+1].mx) * t[x*2].cnt + (t[x*2+1].mx >= t[x*2].mx) * t[x*2+1].cnt;
    }

    void pushadd(int x, int k){
	t[x].sum += (ll)k * t[x].num;
	if(t[x].num){
	    t[x].mx += k, t[x].lazy += k;
	    if(t[x].smx > 0) t[x].smx += k;
	    if(~t[x].tag) t[x].tag += k;
	}
    }

    void pushmin(int x, int k){
	if(t[x].mx <= k) return;
	t[x].sum -= (ll)t[x].cnt * (t[x].mx - k);
	t[x].mx = k, t[x].tag = k;
    }

    void pushdown(int x){
	if(t[x].lazy) pushadd(x*2, t[x].lazy), pushadd(x*2+1, t[x].lazy);
	if(~t[x].tag) pushmin(x*2, t[x].tag), pushmin(x*2+1, t[x].tag);
	t[x].tag = -1, t[x].lazy = 0;
    }

    void insert(int x, int l, int r, int pos, int val){
	if(l == r){ t[x].sum = t[x].mx = val, t[x].cnt = t[x].num = 1; return; }
	int mid = (l+r)>>1;
	pushdown(x);
	if(mid >= pos) insert(x*2, l, mid, pos, val);
	else insert(x*2+1, mid+1, r, pos, val);
	pushup(x);
    }

    void update(int x, int l, int r, int L, int R, int k, int id){
	if(L > R) return;
	if(id && t[x].mx <= k) return;
	if(l >= L && r <= R){
	    if(id && t[x].smx < k){ pushmin(x, k); return; }
	    else if(!id){ pushadd(x, k); return; }
	}
	int mid = (l+r)>>1;
	pushdown(x);
	if(mid >= L) update(x*2, l, mid, L, R, k, id);
	if(mid < R) update(x*2+1, mid+1, r, L, R, k, id);
	pushup(x);
    }
} LB, RB;

struct Fenwick{
    int t[N];
    void update(int pos, int k){
	while(pos <= n) t[pos] += k, pos += lowbit(pos);
    }
    int get(int pos){
	int ret = 0;
	while(pos) ret += t[pos], pos -= lowbit(pos);
	return ret;
    }
} T;

int main(){
    ios::sync_with_stdio(false);
    cin>>n;
    rep(i,1,n) cin>>a[i], pos[a[i]] = i;

    rep(i,1,n){
	RB.update(1, 1, n, pos[i]+1, n, 1, 0);
	RB.update(1, 1, n, 1, pos[i]-1, T.get(pos[i]), 1);
	LB.update(1, 1, n, 1, n-pos[i]+1, -1, 0);
	LB.update(1, 1, n, 1, n-pos[i]+1, n-T.get(pos[i])-1, 1);
	RB.insert(1, 1, n, pos[i], i), LB.insert(1, 1, n, n-pos[i]+1, n);
	cout<< RB.t[1].sum - ((ll)i*(n+1) - LB.t[1].sum) + i <<endl;
	T.update(pos[i], 1);
    }
    return 0;
}

 

Codeforces 1572F Stations

\(n\) 個城市,每個城市有兩個屬性 \(h_i,w_i\),第 \(i\) 個城市的廣播可以覆蓋到所有 \([i,w_i]\) 中滿足 \(\max_{i<k\leq j}\{h_k\}<h_i\) 的城市 \(j\)

初始時所有 \(h_i=0,w_i=i\),每次操作可以使得城市 \(c_i\)\(h\) 成為全局嚴格最大值,並修改 \(w_{c_i}\),或者詢問 \([l,r]\) 中,對於每個城市來說廣播可以覆蓋到它的城市個數之和。

\(1\leq n\leq 2\times 10^5\)

顯然每個城市覆蓋的區域是一個從 \(i\) 開始的區間,而維護區間右端點是一個「 單點修改,區間取 \(\min\)」問題,另外用一棵線段樹維護區間覆蓋,注意到 Segment Tree Beats 里值的修改是直接進行的,所以修改時可以順帶在另一棵線段樹上進行區間修改,即打 \(tag\) 時維護即可。詢問時直接在線段樹查詢區間和。時間復雜度 \(O(n\log^2n)\)


#include<iostream>
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 200021
#define ll long long
using namespace std;

int n, q;

struct SegmentTree{
    struct node{
        ll sum, lazy;
    } t[N<<2];

    void pushdown(int x, int l, int r){
        int mid = (l+r)>>1;
        t[x*2].sum += (mid-l+1) * t[x].lazy, t[x*2].lazy += t[x].lazy;
        t[x*2+1].sum += (r-mid) * t[x].lazy, t[x*2+1].lazy += t[x].lazy;
        t[x].lazy = 0;
    }

    void update(int x, int l, int r, int L, int R, ll k){
        if(L > R) return;
        if(l >= L && r <= R){ t[x].sum += (r-l+1) * k, t[x].lazy += k; return; }
        int mid = (l+r)>>1;
        pushdown(x, l, r);
        if(mid >= L) update(x*2, l, mid, L, R, k);
        if(mid < R) update(x*2+1, mid+1, r, L, R, k);
        t[x].sum = t[x*2].sum + t[x*2+1].sum;
    }

    ll get(int x, int l, int r, int L, int R){
        if(l >= L && r <= R) return t[x].sum;
        int mid = (l+r)>>1; ll ret = 0;
        pushdown(x, l, r);
        if(mid >= L) ret += get(x*2, l, mid, L, R);
        if(mid < R) ret += get(x*2+1, mid+1, r, L, R);
        return ret;
    }
} BIT;

struct Segment_Tree_Beats{
    struct node{
        int mx, smx, cnt, tag = -1;
    } t[N<<2];

    void pushup(int x){
        t[x].mx = max(t[x*2].mx, t[x*2+1].mx);
        if(t[x*2].mx == t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].smx);
        else if(t[x*2].mx > t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].mx);
        else t[x].smx = max(t[x*2].mx, t[x*2+1].smx);
        t[x].cnt = (t[x*2].mx >= t[x*2+1].mx) * t[x*2].cnt + (t[x*2+1].mx >= t[x*2].mx) * t[x*2+1].cnt;
    }

    void addtag(int x, int k, bool frt){
        if(t[x].mx <= k) return;
        if(frt) BIT.update(1, 1, n, k+1, t[x].mx, -t[x].cnt);
        t[x].mx = t[x].tag = k;
    }

    void pushdown(int x){
        if(~t[x].tag)
            addtag(x*2, t[x].tag, 0), addtag(x*2+1, t[x].tag, 0);
        t[x].tag = -1;
    }

    void insert(int x, int l, int r, int pos, int val){
        if(l == r){ 
            BIT.update(1, 1, n, l, t[x].mx, -t[x].cnt), BIT.update(1, 1, n, l, val, 1);
            t[x].mx = val, t[x].cnt = 1; return; 
        }
        int mid = (l+r)>>1;
        pushdown(x);
        if(mid >= pos) insert(x*2, l, mid, pos, val);
        else insert(x*2+1, mid+1, r, pos, val);
        pushup(x);
    }

    void update(int x, int l, int r, int L, int R, int k){
        if(t[x].mx <= k || L > R) return;
        if(l >= L && r <= R && t[x].smx < k){ addtag(x, k, 1); return; }
        int mid = (l+r)>>1;
        pushdown(x);
        if(mid >= L) update(x*2, l, mid, L, R, k);
        if(mid < R) update(x*2+1, mid+1, r, L, R, k);
        pushup(x);
    }
} T;

int main(){
    ios::sync_with_stdio(false);
    cin>>n>>q;
    rep(i,1,n) T.insert(1, 1, n, i, i);
    int type, c, g, l, r;
    while(q--){
        cin>>type;
        if(type == 1){
            cin>>c>>g;
            T.insert(1, 1, n, c, g), T.update(1, 1, n, 1, c-1, c-1);
        } else cin>>l>>r, cout<< BIT.get(1, 1, n, l, r) <<endl;
    }
    return 0;
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM