有趣的 zkw 線段樹(超全詳解)


zkw segment-tree 真是太棒了(真的重口味)!寫篇博客紀念入門

 

emmm...首先我們來介紹一下 zkw 線段樹這個東西(俗稱 "重口味" ,與 KMP 類似,咳咳...



zkw 線段樹的介紹

其實 zkw 線段樹和普通線段樹區別沒多大(區別可大了去了!)

emmm...起碼它們的思想是一致的,都是節點維護區間信息嘛。

只不過...普通線段樹的維護和查詢是遞歸式,而 zkw線段樹是循環式的...

但是不要以為 zkw線段樹只是靠循環加速上位的!

zkw線段樹能支持非常多強(luan)如(qi)閃(ba)電(zao)的操作(最后例題講)。

 

 

 

 zkw 線段樹 與普通線段樹 的比較

 

emmm...這里你看着 普通線段樹 的節點比 zkw線段樹 的小對吧,但其實兩者差不多,(因為線段樹是要開4倍空間的啊,這里只是沒有畫出用不到的節點罷了),

 

 

zkw 線段樹的形態

其實上圖...還是無法體現zkw 線段樹的具體形態的,(但是相信聰明的你一定看懂了所以我就不講了)

emmm...於是乎還是上圖解釋一切

 

 

zkw 線段樹的建立

 

首先你要寫個循環,讓 m 這個值(也就是非葉子節點)大於 n (也就是總葉子結點數),以此保證 這棵樹的葉子 能夠容納你要維護的 n 個值

然后你要從 m 倒推 到 1 號節點(注意是 m 倒推回 1 ,保證維護每個節點時該節點的孩子都已經被維護完畢讓每個節點維護它左右孩子的信息

 

 

代碼實現

這里我們假設要維護的信息有:區間和,區間最小值,區間最大值 。 下同

inline void build(int n){
//  維護這么多信息都只需要這么幾行,可見維護信息單一時代碼應該會短的不像話(壓行過的話大概三四行)
    for(m=1;m<=n;m<<=1);
    for(int i=m+1;i<=m+n;++i)
        sum[i]=mn[i]=mx[i]=read();
    for(int i=m-1;i;--i)
        sum[i]=a[i<<1]+a[i<<1|1],
        mn[i]=min(mn[i<<1],mn[i<<1|1]),
        mx[i]=max(mx[i<<1],mx[i<<1|1]);
} 


但是,這里對 mn 和 mx 的處理是在無修改操作的基礎上實行的,所以這樣寫並不支持修改操作。

 

那么我們可以這樣寫:

inline void build(){
    for(m=1;m<=n;m<<=1);
    for(int i=m+1;i<=m+n;++i)
        sum[i]=mn[i]=mx[i]=read();
    for(int i=m-1;i;--i){
        sum[i]=sum[i<<1]+sum[i<<1|1];
        
        mn[i]=min(mn[i<<1],mn[i<<1|1]),
        mn[i<<1]-=mn[i],mn[i<<1|1]-=mn[i];
        
        mx[i]=max(mx[i<<1],mx[i<<1|1]),
        mx[i<<1]-=mx[i],mx[i<<1|1]-=mx[i];
    }
}

PS:以下的操作(單點、區間更新,單點、區間查詢)所附的代碼,都基於可修改的版本

 

 

zkw 線段樹的更新

單點更新

這個單點更新還是比較好解決的,你只要找到更新的節點所在的葉子結點,然后修改后一直向父節點更新即可。

(這個。。。就不用上圖了吧...你腦補一下就差不多了)

 

代碼實現

這里我們假設將一個節點的值增加 v (修改的話...就記錄一下原數組,然后算差值就好了吧?)

inline void update_node(int x,int v,int A=0){
    x+=m,mx[x]+=v,mn[x]+=v;for(;x>1;x>>=1){
        sum[x]+=v;
        A=min(mn[x],mn[x^1]);
        mn[x]-=A,mn[x^1]-=A,mn[x>>1]+=A;
        A=max(mx[x],mx[x^1]),
        mx[x]-=A,mx[x^1]-=A,mx[x>>1]+=A;
    }
}

 

 

 

區間更新

這個東西...有點麻煩(你得稍微感性理解)。

就是說...你每次要更新一段區間的時候,你要讓左端點 -1 ,右端點 +1 。

然后你在更新權值的時候要判斷 左端點當前所處的節點是否是它父節點的左孩子,

是的話就讓該節點的兄弟(也就是它父節點的右孩子)得到更新,否則不做處理,

然后左節點再向右移一位(也就是跳到了父節點),重復迭代以上步驟。

那么右端點呢?其實也就是和左端點反着來了而已。

還有一點,循環的終止條件?這個簡單,就是當左右端點所處的節點是兄弟節點的時候結束循環。

類似的,你更新一個節點時 同樣可以用這種方法維護(只不過這樣就更麻煩了啊)。

這樣我們可以看到要被更新的區間都已經被染成黃色了。但是,zkw 沒有下傳標記啊!

那么我們查詢的區間如果在染成黃色的節點的下部(也就是黃色節點的子樹內)該怎么辦?

我們可以這樣...這樣...沒錯!標記永久化!

因為我們已經將一個節點的標記永久化了,那么在該節點被訪問到的時候,只要將當前查詢到的、包含在該節點所管轄區間范圍內的  區間長度乘上標記值,累加入答案即可。

(具體實現得看代碼)

 

 

區間更新的特殊情況

 

同學們有沒有注意到一種區間查詢的特殊情況?沒錯,就是右區間+1后到達下一層的特殊情況

就以上圖為例,假設維護區間為 1 ~ 7 ,現在對 2 ~ 7 進行區間加操作,那么  t = 7+1 = 8 ,於是 t 就到達了不存在的第 5 層!

現在你想的一定是這種情況該怎么避免這種情況(其實很簡單,你在建樹確定 m 的值的時候,將判斷條件改成 " m<=n+1 " 就行了)

但我現在要證明這種情況不需要避免也不會出問題(基本上...吧?)

 

 

 

我們可以看到,s 和 t 在跳到 0 和 1 時滿足了終止條件,並且需要更新的節點都得到了更新,而且,其實 t 就沒有更新過節點...

 

 

代碼實現

這里我們假設要將一段區間的每個數加上 v ,然后維護的信息同上

inline void update_part(int s,int t,int v){
    int A=0,lc=0,rc=0,len=1; 
    for(s+=m-1,t+=m+1;s^t^1;s>>=1,t>>=1,len<<=1){ //在這里的 add 就是標記數組了
        if(s&1^1) add[s^1]+=v,lc+=len, mn[s^1]+=v,mx[s^1]+=v;
        if(t&1)    add[t^1]+=v,rc+=len, mn[t^1]+=v,mx[t^1]+=v;
        
        sum[s>>1]+=v*lc, sum[t>>1]+=v*rc;
        
        A=min(mn[s],mn[s^1]),mn[s]-=A,mn[s^1]-=A,mn[s>>1]+=A,
        A=min(mn[t],mn[t^1]),mn[t]-=A,mn[t^1]-=A,mn[t>>1]+=A;
        
        A=max(mx[s],mx[s^1]),mx[s]-=A,mx[s^1]-=A,mx[s>>1]+=A,
        A=max(mx[t],mx[t^1]),mx[t]-=A,mx[t^1]-=A,mx[t>>1]+=A;
    }
    for(lc+=rc;s>1;s>>=1){
        sum[s>>1]+=v*lc;
        A=min(mn[s],mn[s^1]),mn[s]-=A,mn[s^1]-=A,mn[s>>1]+=A,
        A=max(mx[s],mx[s^1]),mx[s]-=A,mx[s^1]-=A,mx[s>>1]+=A;
    }
}

這里的 lc 和 rc 的所代表的含義需要講一下

lc 代表左端點所處的節點下有多少長度的區間在更新區間內, rc 同理 ,通俗一點地說,就是 s 和 t 所分別走過的節點中包含的更新過的區間的總長

 

 

 

 

zkw線段樹的查詢

 

單點查詢

這個沒什么好說的吧,你從葉子結點一直跳父節點,把途中節點的 mn (或者 mx )權值累加,最后得到的就是答案

 

代碼實現

inline int query_node(int x,int ans=0){
    for(x+=m;x;x>>=1) ans+=mn[s];
    return ans;
}

 

 

 

區間查詢

什么?zkw線段樹的區間查詢?我不會啊。     

那么這里的區間查詢...其實有點難說啊!要不就直接上代碼得了?咳咳...

這個其實和上面的區間更新的思路差不多,可能要講的就是標記累加的問題了吧。

那么 lc 和 rc 之前已經講過了,就是 s 節點和 t 節點分別走過的節點中所包含的更新區間的長度。

那么 add 這個數組啊...啊...啊...這個數組啊,它...要不我們直接看代碼吧?

它好在哪里啊?好難說啊...其實它就是記錄了你每次大塊累加區間時的副產品啊,類似於線段樹的懶標記。

但是和普通線段樹不一樣的是,線段樹的查詢是自上而下查詢(順便釋放標記)然后又自下而上的遞歸回去的,

而 zkw 的查詢是直接自下而上的,於是它無法釋放標記,於是它就在遇到某個打過懶標記的節點時,將當前查詢到的區間長度乘上標記值,累加入答案。

(所以這還是懶標記啊!不上圖了自行腦補。emmm...算了吧那還是上一張圖好了

 

代碼實現

inline int query_sum(int s,int t){
    int lc=0,rc=0,len=1,ans=0;
    for(s+=m-1,t+=m+1;s^t^1;s>>=1,t>>=1,len<<=1){
        if(s&1^1) ans+=sum[s^1]+len*add[s^1],lc+=len;
        if(t&1) ans+=sum[t^1]+len*add[t^1],rc+=len;
        
        if(add[s>>1]) ans+=add[s>>1]*lc;
        if(add[t>>1]) ans+=add[t>>1]*rc; 
    }
    for(lc+=rc,s>>=1;s;s>>=1)
        if(add[s]) ans+=add[s]*lc;
    return ans;
}

inline int query_min(int s,int t,int L=0,int R=0,int ans=0){
    if(s==t) return query_node(s);  // 單點要特判, 下同
    for(s+=m,t+=m;s^t^1;s>>=1,t>>=1){ // 這里 s 和 t 直接加上 m
        L+=mn[s],R+=mn[t];
        if(s&1^1) L=min(L,mn[s^1]);
        if(t&1) R=min(R,mn[t^1]);
    }
    for(ans=min(L,R),s>>=1;s;s>>=1) ans+=mn[s];
    return ans;
}

inline int query_max(int s,int t,int L=0,int R=0,int ans=0){
    if(s==t) return query_node(s);
    for(s+=m,t+=m;s^t^1;s>>=1,t>>=1){
        L+=mx[s],R+=mx[t];
        if(s&1^1) L=max(L,mx[s^1]);
        if(t&1) R=max(R,mx[t^1]);
    }
    for(ans=max(L,R),s>>=1;s;s>>=1) ans+=mx[s];
    return ans;
}

這里詢問時 s 和 t 不能 -1 或 +1 ,不然會查詢到旁邊不相干的節點。

然后 s == t 的情況要特判一下,防止 s 和 t 一直都不是兄弟,陷入死循環。

 

 

 

zkw 的代碼實現(模板)

 

完全代碼

 

  1 //by Judge
  2 #include<cstdio>
  3 #include<iostream>
  4 using namespace std;
  5 const int M=1e5+111;
  6 //#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
  7 char buf[1<<21],*p1=buf,*p2=buf;
  8 inline int read(){
  9     int x=0,f=1; char c=getchar();
 10     for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
 11     for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f;
 12 }
 13 char sr[1<<21],z[20];int C=-1,Z;
 14 inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
 15 inline void print(int x){
 16     if(C>1<<20)Ot();if(x<0)sr[++C]=45,x=-x;
 17     while(z[++Z]=x%10+48,x/=10);
 18     while(sr[++C]=z[Z],--Z);sr[++C]='\n';
 19 }
 20 int n,m,q;
 21 int sum[M<<2],mn[M<<2],mx[M<<2],add[M<<2];
 22 inline void build(){
 23     for(m=1;m<=n;m<<=1);
 24     for(int i=m+1;i<=m+n;++i)
 25         sum[i]=mn[i]=mx[i]=read();
 26     for(int i=m-1;i;--i){
 27         sum[i]=sum[i<<1]+sum[i<<1|1];
 28         mn[i]=min(mn[i<<1],mn[i<<1|1]),
 29         mn[i<<1]-=mn[i],mn[i<<1|1]-=mn[i];
 30         mx[i]=max(mx[i<<1],mx[i<<1|1]),
 31         mx[i<<1]-=mx[i],mx[i<<1|1]-=mx[i];
 32     }
 33 }
 34 inline void update_node(int x,int v,int A=0){
 35     x+=m,mx[x]+=v,mn[x]+=v,sum[x]+=v;
 36     for(;x>1;x>>=1){
 37         sum[x]+=v;
 38         A=min(mn[x],mn[x^1]);
 39         mn[x]-=A,mn[x^1]-=A,mn[x>>1]+=A;
 40         A=max(mx[x],mx[x^1]),
 41         mx[x]-=A,mx[x^1]-=A,mx[x>>1]+=A;
 42     }
 43 }
 44 inline void update_part(int s,int t,int v){
 45     int A=0,lc=0,rc=0,len=1;
 46     for(s+=m-1,t+=m+1;s^t^1;s>>=1,t>>=1,len<<=1){
 47         if(s&1^1) add[s^1]+=v,lc+=len, mn[s^1]+=v,mx[s^1]+=v;
 48         if(t&1)    add[t^1]+=v,rc+=len, mn[t^1]+=v,mx[t^1]+=v;
 49         sum[s>>1]+=v*lc, sum[t>>1]+=v*rc;
 50         A=min(mn[s],mn[s^1]),mn[s]-=A,mn[s^1]-=A,mn[s>>1]+=A,
 51         A=min(mn[t],mn[t^1]),mn[t]-=A,mn[t^1]-=A,mn[t>>1]+=A;
 52         A=max(mx[s],mx[s^1]),mx[s]-=A,mx[s^1]-=A,mx[s>>1]+=A,
 53         A=max(mx[t],mx[t^1]),mx[t]-=A,mx[t^1]-=A,mx[t>>1]+=A;
 54     }
 55     for(lc+=rc;s;s>>=1){
 56         sum[s>>1]+=v*lc;
 57         A=min(mn[s],mn[s^1]),mn[s]-=A,mn[s^1]-=A,mn[s>>1]+=A,
 58         A=max(mx[s],mx[s^1]),mx[s]-=A,mx[s^1]-=A,mx[s>>1]+=A;
 59     }
 60 }
 61 inline int query_node(int x,int ans=0){
 62     for(x+=m;x;x>>=1) ans+=mn[x]; return ans;
 63 }
 64 inline int query_sum(int s,int t){
 65     int lc=0,rc=0,len=1,ans=0;
 66     for(s+=m-1,t+=m+1;s^t^1;s>>=1,t>>=1,len<<=1){
 67         if(s&1^1) ans+=sum[s^1]+len*add[s^1],lc+=len;
 68         if(t&1) ans+=sum[t^1]+len*add[t^1],rc+=len;
 69         if(add[s>>1]) ans+=add[s>>1]*lc;
 70         if(add[t>>1]) ans+=add[t>>1]*rc; 
 71     }
 72     for(lc+=rc,s>>=1;s;s>>=1) if(add[s]) ans+=add[s]*lc;
 73     return ans;
 74 }
 75 inline int query_min(int s,int t,int L=0,int R=0,int ans=0){
 76     if(s==t) return query_node(s);
 77     for(s+=m,t+=m;s^t^1;s>>=1,t>>=1){
 78         L+=mn[s],R+=mn[t];
 79         if(s&1^1) L=min(L,mn[s^1]);
 80         if(t&1) R=min(R,mn[t^1]);
 81     }
 82     for(ans=min(L,R),s>>=1;s;s>>=1) ans+=mn[s];
 83     return ans;
 84 }
 85 inline int query_max(int s,int t,int L=0,int R=0,int ans=0){
 86     if(s==t) return query_node(s);
 87     for(s+=m,t+=m;s^t^1;s>>=1,t>>=1){
 88         L+=mx[s],R+=mx[t];
 89         if(s&1^1) L=max(L,mx[s^1]);
 90         if(t&1) R=max(R,mx[t^1]);
 91     }
 92     for(ans=max(L,R),s>>=1;s;s>>=1) ans+=mx[s];
 93     return ans;
 94 }
 95 
 96 signed main(){
 97     
 98     
 99     
100     
101     
102     return 0;
103 }
View Code

 

 

 

板子題?這個真沒有...(不過你可以拿普通線段樹的板子題等練手)

默默放上線段樹板子題的鏈接... 

  1. 線段樹 1 

  2. 線段樹 2

代碼

1.

 1 //by Judge
 2 #include<cstdio>
 3 #include<iostream>
 4 #define ll long long
 5 using namespace std;
 6 const int M=1e5+111;
 7 //#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
 8 char buf[1<<21],*p1=buf,*p2=buf;
 9 inline ll read(){
10     ll x=0,f=1; char c=getchar();
11     for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
12     for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f;
13 }
14 char sr[1<<21],z[20];int C=-1,Z;
15 inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
16 inline void print(ll x){
17     if(C>1<<20)Ot();if(x<0)sr[++C]=45,x=-x;
18     while(z[++Z]=x%10+48,x/=10);
19     while(sr[++C]=z[Z],--Z);sr[++C]='\n';
20 }
21 ll n,m,q;
22 ll sum[M<<2],add[M<<2];
23 inline void build(){
24     for(m=1;m<=n;m<<=1);
25     for(int i=m+1;i<=m+n;++i) sum[i]=read();
26     for(int i=m-1;i;--i) sum[i]=sum[i<<1]+sum[i<<1|1];
27 }
28 inline void update_part(int s,int t,ll v){
29     ll A=0,lc=0,rc=0,len=1;
30     for(s+=m-1,t+=m+1;s^t^1;s>>=1,t>>=1,len<<=1){
31         if(s&1^1) add[s^1]+=v,lc+=len;
32         if(t&1)    add[t^1]+=v,rc+=len;
33         sum[s>>1]+=v*lc,sum[t>>1]+=v*rc;
34     } for(lc+=rc,s>>=1;s;s>>=1) sum[s]+=v*lc; 
35 }
36 inline ll query_sum(int s,int t){
37     ll lc=0,rc=0,len=1,ans=0;
38     for(s+=m-1,t+=m+1;s^t^1;s>>=1,t>>=1,len<<=1){
39         if(s&1^1) ans+=sum[s^1]+len*add[s^1],lc+=len;
40         if(t&1) ans+=sum[t^1]+len*add[t^1],rc+=len;
41         if(add[s>>1]) ans+=add[s>>1]*lc;
42         if(add[t>>1]) ans+=add[t>>1]*rc; 
43     } for(lc+=rc,s>>=1;s;s>>=1) if(add[s]) ans+=add[s]*lc;
44     return ans;
45 }
46 signed main(){
47     n=read(),q=read(),build();
48     int opt,x,y; ll k;
49     while(q--){
50         opt=read(),x=read(),y=read();
51         if(opt&1) k=read(),update_part(x,y,k);
52         else print(query_sum(x,y));
53     } Ot(); return 0;
54 }
View Code

 

2.

emmm...實在是太晚啦(其實是沒有研究過區間乘),所以就...您就自個兒研究吧~~~

 

 

推薦例題

 

題目:  無聊的數列

 

其實這道題用普通線段樹 + 懶標記也可以做 (你可以試試?)

但是用了 zkw 之后...那個代碼量的差別,我都不想說什么...(誒?貌似普通線段樹用了標記永久化之后差不多也是這個碼量?

 

代碼

 1 //by Judge
 2 #include<cstdio>
 3 #include<iostream>
 4 using namespace std;
 5 const int M=1<<20;
 6 int n,m,q,opt,L,R,k,d;
 7 int a[M],lt[M],dt[M];
 8 //#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
 9 char buf[1<<21],*p1=buf,*p2=buf;
10 inline int read(){
11     int x=0,f=1; char c=getchar();
12     for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
13     for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f;
14 }
15 char sr[1<<21],z[20];int C=-1,Z;
16 inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
17 inline void print(int x){
18     if(C>1<<20)Ot();if(x<0)sr[++C]=45,x=-x;
19     while(z[++Z]=x%10+48,x/=10);
20     while(sr[++C]=z[Z],--Z);sr[++C]='\n'; Ot();
21 }
22 inline void build(){
23     for(int i=m;i;--i) lt[i]=lt[i<<1];
24 }
25 inline void update(int L,int R,int k,int d){ //update 還是蠻常規的
26     for(int l=L+m-1,r=R+m+1;l^r^1;l>>=1,r>>=1){
27         if(l&1^1) a[l^1]+=k+(lt[l^1]-L)*d,dt[l^1]+=d;
28         if(r&1) a[r^1]+=k+(lt[r^1]-L)*d,dt[r^1]+=d;
29     }
30 }
31 inline int query(int p,int res){ //query 感性理解一下:非葉子節點存儲的是附加值,也就是操作 1 當中加入的等差數列
32     for(int i=m+p;i;i>>=1) res+=a[i]+(p-lt[i])*dt[i];
33     return res;
34 }
35 int main(){
36     n=read(),q=read(); for(m=1;m<=n;m<<=1); printf("%d\n",m);
37     for(int i=1;i<=n;++i) a[m+i]=read(),lt[m+i]=i;
38     build();
39     while(q--){
40         opt=read();
41         if(opt&1) L=read(),R=read(),k=read(),d=read(),update(L,R,k,d);
42         else k=read(),print(query(k,0));
43     } Ot(); return 0;
44 }
View Code

 

 

 

 

 

 

最后推薦一下:  某位大佬的 blog (寫的也蠻詳細的但沒我詳細,emmm...但是他那片博客里的區間求最值是錯的,坑!)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM