關於線段樹的感悟(Segment Tree)


線段樹的感悟 : 學過的東西一定要多回頭看看,不然真的會忘個干干凈凈。

線段樹的 Introduction :

English Name : Segment Tree
顧名思義 : 該數據結構由兩個重要的東西組成 : 線段,樹,連起來就是在樹上的線段。
想一下,線段有啥特征 ?
不就是兩個端點中間一條線嗎,哈哈,也可以這么理解,但這樣是不是稍微難聽呀,所以
我們用一個華麗的詞語來描述這個兩點之間的一條線,這個詞語就是不知道哪個先知發
明的,就是 -- 區間。
所以我們就可猜想到,所以線段樹一定是用來處理區間問題的。

線段樹長個啥樣子?

展示一個區間  1 - 10 的一顆線段樹,就是這么個樹東西。

線段樹的基本結構 :

1、線段樹的每個節點都代表一個區間
2、線段樹具有唯一的根節點,代表的區間的整個統計范圍,[1,N]
3、線段樹的每個葉節點都代表一個長度為 1 的元區間 [x,x],也就是我們原數組中每個值,原數組中有幾個值
   就有多少個葉子節點(可以參照上圖了解一下)。
4、對於每個內部節點 [l,r],它的左子節點是 [l,mid],右子節點是 [mid + 1,r],mid = l + r >> 1(向下取整)

線段樹經常處理那些區間問題 ?

1、單點查詢(查詢某個位置上的值是多少)
2、單點修改(修改某個位置上的值)
3、區間查詢(查詢某個區間的 和、最大值、最小值、最大公約數、and so on)
4、區間修改(修改某個區間的值, eg:讓某個區間都 + 一個數、and so on)

線段樹需要注意的地方 :

1、結構體空間一定要開 4 倍,一定要記得看 4 倍(看上面這棵樹,按節點編號我們可以看到一共有 25 個節點,但算上空余的位置呢?)
   會發現有 31 個節點,可以自己數一下,所以我們要開原數組的 4 倍,避免出現數組越界,非法訪問的情況(段錯誤)。
2、區間判斷的時候一定不要寫反(下面寫的時候就知道了,這個坑讓我 Debug 了一個多小時)
3、沒事多打打,模板,就當練手速了。

線段樹的基本操作 :

1、Struct結構體存儲

struct node {
	LL l,r;
	LL sum;  // 看需要向父節點傳送什么
} tr[maxn << 2];

2、 Build

void pushup(LL u) {
	tr[u].sum = gcd(tr[u << 1].sum,tr[u << 1 | 1].sum);
	return ;
}

void build(LL u,LL l,LL r) {
	tr[u].l = l,tr[u].r = r;  // 初始化(節點 u 代表區間 [l,r])
	if(l == r) {
		tr[u].sum = b[l]; // 遞歸到葉節點賦初值
		return ;
	}
	LL mid = l + r >> 1;      // 折半
	build(u << 1,l,mid);      // 向左子節點遞歸
	build(u << 1 | 1,mid + 1,r); // 向右子節點遞歸
	pushup(u);                // 從下往上傳遞信息
	return;
}

3、Update

void update(LL u,LL x,LL v) {
	if(tr[u].l == tr[u].r) {        // 找到葉節點
		tr[u].sum += v;         // 在某個位置加上一個數
		return ;
	}
	LL mid = tr[u].l + tr[u].r >> 1;
	if(x <= mid) update(u << 1,x,v); // x 屬於左半區間
	else update(u << 1 | 1,x,v);     // x 屬於右半區間
	pushup(u);                       // 從下向上更新信息
	return ;
}

4、Query :

1、若 [l,r] 完全覆蓋了當前節點代表的區間,則立即回溯。
2、若左子節點與 [l,r] 有重疊部分,則遞歸訪問左子節點。
3、若右子節點與 [l,r] 有重疊部分,則遞歸訪問右子節點。
LL query(int u,int l,int r) {
    if(tr[u].l >= l && tr[u].r <= r) {   // 完全包含
        return tr[u].sum;
    }
    int mid = tr[u].l + tr[u].r >> 1;
    LL sum = 0;
    if(l <= mid) sum += query(u << 1,l,r);
    if(r > mid) sum += query(u << 1 | 1,l,r);
    return sum; 
}

上述就是線段樹的基本操作,基本上都是圍繞單點問題進行操作,如果要涉及到復雜的區間操作,
例如 : 給區間 [l,r] 每個數都 + d
這時如果還用上述操作,我們就需要進行 l - r + 1 次操作,如果有多次這樣的操作,顯然時間
復雜度會很高,這時候我們應該選擇什么樣的方法來降低時間復雜度呢 ?

Lazy(懶) 標記應運而生

簡單一點來說就是,減少重復的操作,如果說我們操作的每一個數都在一個區間范圍內,那么
我們就可以直接處理這個區間,不需要再一個一個處理,比如上面的給區間的每一個數 + d;
假設說我們已經知道 [l,r] 完全包含一個區間 [x,y],也就是說 區間[x,y]是 [l,r]的
一個子區間,那么這個時候我們是不是直接可以計算出 [x,y] 這個區間 都 + d 后的值是
多少, (x - y + 1) * d(假設是求和的話),這樣我們就可以不再用去一個一個加,然后
再合並了,我們知道有這樣的區間后,怎么用呢?這時候就需要進行標記一下,便於我們知道
這個地方有一個區間可以直接處理,不需要再麻煩着向下繼續去處理了,是不是很懶,哈哈。
/*
    懶標記的含義 : 該節點曾經被修改,但其子節點尚未被更新。
    在后續的指令中,我們需要從某個節點向下遞歸時,檢查該節點是否具有標記,若有標記,就根據
    標記信息更新 該節點 的兩個子節點,同時為該節點的兩個子節點增加標記,然后清楚 p 的標記。
*/
void pushdown(int u) {
    if(tr[u].lazy) {    // 節點 u 有標記
        tr[u << 1].sum += tr[u].lazy * (tr[u << 1].r - tr[u << 1].l + 1); // 更新左子節點信息
        tr[u << 1| 1].sum += tr[u].lazy * (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1); // 更新右子節點
        tr[u << 1].lazy += tr[u].lazy;     // 給左子節點打延遲標記
        tr[u << 1 | 1].lazy += tr[u].lazy; // 給右子節點打延遲標記
        tr[u].lazy = 0;                    // 清楚父節點的延遲標記(這點很重要)
    }
    return ;
}

加上 Lazy 標記的其他操作 :

// Build 不變
// Update
void modify(int u,int l,int r,int x) {
    if(tr[u].l >= l && tr[u].r <= r) {  // 完全覆蓋
        tr[u].sum += (tr[u].r- tr[u].l + 1) * x; // 更新節點信息
        tr[u].lazy += x;                // 給節點打延遲標記
        return ;
    }
    pushdown(u);                        // 下傳延遲標記
    int mid = tr[u].l + tr[u].r >> 1;
    if(l <= mid) modify(u << 1,l,r,x);
    if(r > mid) modify(u << 1 | 1,l,r,x);
    pushup(u);
    return ;
}

// Query
LL query(int u,int l,int r) {
    if(tr[u].l >= l && tr[u].r <= r) {
        return tr[u].sum;
    }
    pushdown(u);                  // 同上
    int mid = tr[u].l + tr[u].r >> 1;
    LL sum = 0;
    if(l <= mid) sum += query(u << 1,l,r);
    if(r > mid) sum += query(u << 1 | 1,l,r);
    return sum; 
}

總結 :

線段樹的操作基本上就這些,哈哈,實際上自己就了解這么多,而且是最近有幾場比賽碰見挺多的,就學了一下,
主要是手得多動動,有時候考察得還是比較復雜得,先把這些基礎得模板搞懂吧。

例題(模板題):

1、 一個簡單的整數問題

#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

const int maxn = 1e5 + 10;
typedef long long LL;

struct node {
    int l,r;
    LL sum,lazy;
}tr[maxn << 2];
int a[maxn];
int n,m;
int l,r;

int main(void) {
    void build(int u,int l,int r);
    void modify(int u,int l,int r,int x);
    LL query(int u,int l,int r);
    scanf("%d%d",&n,&m);
    for(int i = 1; i <= n; i ++) {
        scanf("%d",&a[i]);
    }
    build(1,1,n);
    while(m --) {
        char ch;
        cin >> ch;
        if(ch == 'Q') {
            scanf("%d",&l); 
            printf("%lld\n",query(1,1,l) - query(1,1,l - 1));
        } else {
            int value;
            scanf("%d%d%d",&l,&r,&value);
            modify(1,l,r,value);
        }
    }
    return 0;
} 

void pushup(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
    return ;
}

void pushdown(int u) {
    if(tr[u].lazy) {
        tr[u << 1].sum += tr[u].lazy * (tr[u << 1].r - tr[u << 1].l + 1);
        tr[u << 1| 1].sum += tr[u].lazy * (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1);
        tr[u << 1].lazy += tr[u].lazy;
        tr[u << 1 | 1].lazy += tr[u].lazy;
        tr[u].lazy = 0;
    }
    return ;
}

void build(int u,int l,int r) {
    tr[u].l = l,tr[u].r = r;
    if(l == r) {
        tr[u].sum = a[l];
        return ;
    }
    int mid = l + r >> 1;
    build(u << 1,l,mid);
    build(u << 1 | 1,mid + 1,r);
    pushup(u);
    return ;
}

void modify(int u,int l,int r,int x) {
    if(tr[u].l >= l && tr[u].r <= r) {
        tr[u].sum += (tr[u].r- tr[u].l + 1) * x;
        tr[u].lazy += x;
        return ;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if(l <= mid) modify(u << 1,l,r,x);
    if(r > mid) modify(u << 1 | 1,l,r,x);
    pushup(u);
    return ;
}

LL query(int u,int l,int r) {
    if(tr[u].l >= l && tr[u].r <= r) {
        return tr[u].sum;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    LL sum = 0;
    if(l <= mid) sum += query(u << 1,l,r);
    if(r > mid) sum += query(u << 1 | 1,l,r);
    return sum; 
}

2、一個簡單的整數問題2

#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

const int maxn = 1e5 + 10;
typedef long long LL;

struct node {
    int l,r;
    LL sum,lazy;
}tr[maxn << 2];
int a[maxn];
int n,m;
int l,r;

int main(void) {
    void build(int u,int l,int r);
    void modify(int u,int l,int r,int x);
    LL query(int u,int l,int r);
    scanf("%d%d",&n,&m);
    for(int i = 1; i <= n; i ++) {
        scanf("%d",&a[i]);
    }
    build(1,1,n);
    while(m --) {
        char ch;
        cin >> ch;
        if(ch == 'Q') {
            scanf("%d%d",&l,&r);    
            printf("%lld\n",query(1,l,r) );
        } else {
            int value;
            scanf("%d%d%d",&l,&r,&value);
            modify(1,l,r,value);
        }
    }
    return 0;
} 

void pushup(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
    return ;
}

void pushdown(int u) {
    if(tr[u].lazy) {
        tr[u << 1].sum += tr[u].lazy * (tr[u << 1].r - tr[u << 1].l + 1);
        tr[u << 1| 1].sum += tr[u].lazy * (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1);
        tr[u << 1].lazy += tr[u].lazy;
        tr[u << 1 | 1].lazy += tr[u].lazy;
        tr[u].lazy = 0;
    }
    return ;
}

void build(int u,int l,int r) {
    tr[u].l = l,tr[u].r = r;
    if(l == r) {
        tr[u].sum = a[l];
        return ;
    }
    int mid = l + r >> 1;
    build(u << 1,l,mid);
    build(u << 1 | 1,mid + 1,r);
    pushup(u);
    return ;
}

void modify(int u,int l,int r,int x) {
    if(tr[u].l >= l && tr[u].r <= r) {
        tr[u].sum += (tr[u].r- tr[u].l + 1) * x;
        tr[u].lazy += x;
        return ;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if(l <= mid) modify(u << 1,l,r,x);
    if(r > mid) modify(u << 1 | 1,l,r,x);
    pushup(u);
    return ;
}

LL query(int u,int l,int r) {
    if(tr[u].l >= l && tr[u].r <= r) {
        return tr[u].sum;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    LL sum = 0;
    if(l <= mid) sum += query(u << 1,l,r);
    if(r > mid) sum += query(u << 1 | 1,l,r);
    return sum; 
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM