zkw線段樹學習筆記


ZKW線段樹

應某迪要求,寫一篇數據結構學習筆記。

實際上還沒有學很多東西,只是一些基礎的操作。

zkw線段樹的學習資料,網上有很多,這里記錄的只是自己的一些理解。

 

 

建樹

1 inline void build(){
2     for(bit=1,n=read();bit<=n+1;bit<<=1);
3     for(int i=bit+1;i<=bit+n;++i) sum[i]=read();
4     for(int i=bit-1;i;--i) sum[i]=sum[i<<1]+sum[i<<1|1];
5 }

$zkw$線段樹構造了一棵完美二叉樹,只有最后一層葉子節點管轄的區間大小為1。

$zkw$線段樹是基於位運算的,對於節點$p$,$p<<1$為它的左兒子,$p<<1|1$為它的右兒子。

因為是一棵完美二叉樹,除掉葉子節點的部分一定為$2^k-1$的形式,將這個$2^k$記為$bit$,可以方便我們之后的操作。

其意義是,對於原序列的點$i$,可以直接得到對應線段樹上的節點$i+bit$。

注意這里我們忽略了$bit$也就是$2^k$這一個節點,以后再提。

同時建樹的一個細節是$bit$應當大於$n+1$,其原因也可以留到后面。

 

 

單點修改

1 inline void modify(int p,int val){
2     for(p+=bit;p;p>>=1) sum[p]+=val;
3 }

找到位置之后,直接修改一條祖先鏈。

 

 

區間修改

1 inline void modify(int l,int r,int val){
2     int lc=0,rc=0,len=1;
3     for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){
4         sum[l]+=lc*val; sum[r]+=rc*val;
5         if(~l&1) sum[l^1]+=len*val,add[l^1]+=val,lc+=len;
6         if(r&1) sum[r^1]+=len*val,add[r^1]+=val,rc+=len;
7     }
8     for(;l;l>>=1,r>>=1) sum[l]+=lc*val,sum[r]+=rc*val;
9 }

$lc$:當前左指針包含的區間長度。$rc$:當前右指針包含的區間長度。$len$:當前翻到的節點層管轄區間的長度。

這里我們將$l$,$r$都作為開區間。所以分別加$bit-1$,$bit+1$處理。

因為操作是自下而上進行的,$zkw$線段樹一般不維護懶標記,

因而我們用一個數組$add$進行標記永久化,表示這個區間的所有序列應該被加上這個值,顯然這個值是不能下傳的。

當$l$的最后一位為0,也就是說$l$指針為左兒子,那么l的右兄弟在當前修改的區間內。

同理$r$的左兄弟會在修改區間內。

當$l$,$r$兩個指針已經成為兄弟,也就是說二者在二進制下只有最后一位不同,即異或值為1,那么全部的修改操作已經完成,可以結束。

然而祖先鏈上的$sum$值仍然需要修改。

這里可以解釋,為什么$bit$應該大於$n+1$而不是$n$,為什么$bit+0$這個節點需要被空出來,因為我們需要開區間來進行操作。

然而似乎使$bit$僅保證大於$n$的打法是正確的,手玩確實沒有錯誤。

 

 

單點查詢

1 inline int query(int p){
2     int ans=0;
3     for(p+=bit,ans=sum[p],p>>=1;p;p>>=1) ans+=add[p];
4     return ans;
5 }

統計葉子節點的$sum$值,並不斷加上祖先鏈的$add$標記即可。

應當注意的是不要加上葉子節點的$add$標記,這個標記是無意義的。

 

 

區間查詢

 1 inline int query(int l,int r){
 2     int ans=0,lc=0,rc=0,len=1;
 3     for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){
 4         ans+=add[l]*lc+add[r]*rc;
 5         if(~l&1) ans+=sum[l^1],lc+=len;
 6         if(r&1) ans+=sum[r^1],rc+=len;
 7     }
 8     for(;l;l>>=1,r>>=1) ans+=add[l]*lc+add[r]*rc;
 9     return ans;
10 }

區間查詢的打法是類似於區間修改的。

首先將$l$,$r$設為開區間。

不斷翻祖先鏈,記得加上兄弟節點整體的$sum$值和祖先鏈上部分的$add$標記就可以了。

應當注意的是循環中統計$add$標記和兄弟$sum$值的順序不可交換,否則可能導致$lc$,$rc$變量維護的含義錯誤。

 

 

區間最值

思想大概是將兒子的最值不斷差分到父親身上。

因為不同的部分已經被差分掉,可以直接修改區間的最值。

應當注意的是,區間最值的求法與區間求和不同。

為了減少特判,原本的開區間被轉化為閉區間。

但這樣產生一個問題,如果查詢區間長度為1會導致一些問題:左右端點永遠不會成為兄弟,故導致了死循環。

所以要加一個單點查詢的特判。

因為要維護區間最值,修改操作同時也要不斷差分,於是打的麻煩了許多,代碼可以參考下面。

 

 

基礎操作

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 const int N=1e6+7;
 4 inline int read(register int x=0,register char ch=getchar(),register int f=0){
 5     while(!isdigit(ch)) f=ch=='-',ch=getchar();
 6     while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
 7     return f?-x:x;
 8 }
 9 int n,m,bit;
10 int sum[N<<2],add[N<<2],mn[N<<2],mx[N<<2];
11 inline void build(){
12     for(bit=1;bit<=n+1;bit<<=1);
13     for(int i=bit+1;i<=bit+n;++i) mx[i]=mn[i]=sum[i]=read();
14     for(int i=bit-1;i;--i){
15         sum[i]=sum[i<<1]+sum[i<<1|1];
16         mn[i]=min(mn[i<<1],mn[i<<1|1]); mn[i<<1]-=mn[i]; mn[i<<1|1]-=mn[i];
17         mx[i]=max(mx[i<<1],mx[i<<1|1]); mx[i<<1]-=mx[i]; mx[i<<1|1]-=mx[i];
18     }
19 }
20 inline int query(int p){
21     int ans=0;
22     for(p+=bit,ans=sum[p],p>>=1;p;p>>=1) ans+=add[p];
23     return ans;
24 }
25 inline int query(int l,int r){
26     int ans=0,lc=0,rc=0,len=1;
27     for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){
28         ans+=add[l]*lc+add[r]*rc;
29         if(~l&1) ans+=sum[l^1],lc+=len;
30         if(r&1) ans+=sum[r^1],rc+=len;
31     }
32     for(;l;l>>=1,r>>=1) ans+=add[l]*lc+add[r]*rc;
33     return ans;
34 }
35 inline int query_min(int l,int r){
36     if(l==r) return query(l);
37     int lans=0,rans=0;
38     for(l+=bit,r+=bit;l^r^1;l>>=1,r>>=1){
39         lans+=mn[l]; rans+=mn[r];
40         if(~l&1) lans=min(lans,mn[l^1]);
41         if(r&1) rans=min(rans,mn[r^1]);
42     }
43     for(lans=min(lans+mn[l],rans+mn[r]),l>>=1;l;l>>=1) lans+=mn[l];
44     return lans;
45 }
46 inline int query_max(int l,int r){
47     if(l==r) return query(l);
48     int lans=0,rans=0;
49     for(l+=bit,r+=bit;l^r^1;l>>=1,r>>=1){
50         lans+=mx[l]; rans+=mx[r];
51         if(~l&1) lans=max(lans,mx[l^1]);
52         if(r&1) rans=max(rans,mx[r^1]);
53     }
54     for(lans=max(lans+mx[l],rans+mx[r]),l>>=1;l;l>>=1) lans+=mx[l];
55     return lans;
56 }
57 inline void modify(int l,int r,int val){
58     int lc=0,rc=0,len=1,x;
59     for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){
60         sum[l]+=lc*val; sum[r]+=rc*val;
61         if(~l&1) sum[l^1]+=len*val,add[l^1]+=val,mn[l^1]+=val,lc+=len;
62         if(r&1) sum[r^1]+=len*val,add[r^1]+=val,mn[r^1]+=val,rc+=len;
63         x=min(mn[l],mn[l^1]); mn[l]-=x; mn[l^1]-=x; mn[l>>1]+=x;
64         x=min(mn[r],mn[r^1]); mn[r]-=x; mn[r^1]-=x; mn[r>>1]+=x;
65     }
66     for(;l;l>>=1,r>>=1){
67         sum[l]+=lc*val; sum[r]+=rc*val;
68         x=min(mn[l],mn[l^1]); mn[l]-=x; mn[l^1]-=x; mn[l>>1]+=x;
69         x=max(mx[l],mx[l^1]); mx[l]-=x; mx[l^1]-=x; mx[l>>1]+=x;
70     }
71 }
72 inline void modify(int p,int val){
73     int x;
74     for(p+=bit;p;p>>=1){
75         sum[p]+=val; mn[p]+=val; mx[p]+=val;
76         x=min(mn[p],mn[p^1]); mn[p]-=x; mn[p^1]-=x; mn[p>>1]+=x;
77         x=max(mx[p],mx[p^1]); mx[p]-=x; mx[p^1]-=x; mx[p>>1]+=x;
78     }
79 }
80 int main(){
81     n=read();
82     build();
83     return 0;
84 }

 

 

區間信息合並(山海經)

思想大概與區間查詢最值一致。

為了減少特判將開區間轉化為閉區間。

需要注意的是信息合並有左右的先后順序。

所以左右指針的寫法並不相同,最后將$l$,$r$掃過的信息合並就可以了。

 1 #include<bits/stdc++.h>
 2 const int N=100010;
 3 inline int read(register int x=0,register char ch=getchar(),bool f=0){
 4     while(!isdigit(ch)) f=ch=='-',ch=getchar();
 5     while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
 6     return f?-x:x;
 7 }
 8 struct Ans{
 9     int l,r,val;
10 };
11 struct Node{
12     Ans la,ra,mx,tot;
13 }s[N<<2];
14 int n,m,bit;
15 inline bool operator <(const Ans &a,const Ans &b){
16     return a.val<b.val||(a.val==b.val&&a.l>b.l)||(a.val==b.val&&a.l==b.l&&a.r>b.r);
17 }
18 inline Ans operator +(const Ans &a,const Ans &b){
19     return (Ans){a.l,b.r,a.val+b.val};
20 }
21 inline Node operator +(const Node &a,const Node &b){
22     return (Node){std::max(a.la,a.tot+b.la),std::max(b.ra,a.ra+b.tot),std::max(std::max(a.mx,b.mx),a.ra+b.la),a.tot+b.tot};
23 }
24 void build(){
25     for(bit=1;bit<=n+1;bit<<=1);
26     for(int i=bit+1;i<=bit+n;++i) s[i].tot=s[i].mx=s[i].la=s[i].ra=(Ans){i-bit,i-bit,read()};
27     for(int i=bit-1;i;--i) s[i]=s[i<<1]+s[i<<1|1];
28 }
29 Ans query(int r,int l){
30     if(l==r) return s[l+bit].mx;
31     Node L=s[l+bit],R=s[r+bit];
32     for(l+=bit,r+=bit;l^r^1;l>>=1,r>>=1){
33         if(~l&1) L=L+s[l^1];
34         if(r&1) R=s[r^1]+R;
35     }
36     return (L+R).mx;
37 }
38 void print(const Ans &x){
39     printf("%d %d %d\n",x.l,x.r,x.val);
40 }
41 int main(){
42     n=read(); m=read(); build();
43     for(int i=1;i<=m;++i) print(query(read(),read()));
44     return 0;
45 }

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM