ZKW線段樹
應某迪要求,寫一篇數據結構學習筆記。
實際上還沒有學很多東西,只是一些基礎的操作。
zkw線段樹的學習資料,網上有很多,這里記錄的只是自己的一些理解。
建樹
1 inline void build(){ 2 for(bit=1,n=read();bit<=n+1;bit<<=1); 3 for(int i=bit+1;i<=bit+n;++i) sum[i]=read(); 4 for(int i=bit-1;i;--i) sum[i]=sum[i<<1]+sum[i<<1|1]; 5 }
$zkw$線段樹構造了一棵完美二叉樹,只有最后一層葉子節點管轄的區間大小為1。
$zkw$線段樹是基於位運算的,對於節點$p$,$p<<1$為它的左兒子,$p<<1|1$為它的右兒子。
因為是一棵完美二叉樹,除掉葉子節點的部分一定為$2^k-1$的形式,將這個$2^k$記為$bit$,可以方便我們之后的操作。
其意義是,對於原序列的點$i$,可以直接得到對應線段樹上的節點$i+bit$。
注意這里我們忽略了$bit$也就是$2^k$這一個節點,以后再提。
同時建樹的一個細節是$bit$應當大於$n+1$,其原因也可以留到后面。
單點修改
1 inline void modify(int p,int val){ 2 for(p+=bit;p;p>>=1) sum[p]+=val; 3 }
找到位置之后,直接修改一條祖先鏈。
區間修改
1 inline void modify(int l,int r,int val){ 2 int lc=0,rc=0,len=1; 3 for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){ 4 sum[l]+=lc*val; sum[r]+=rc*val; 5 if(~l&1) sum[l^1]+=len*val,add[l^1]+=val,lc+=len; 6 if(r&1) sum[r^1]+=len*val,add[r^1]+=val,rc+=len; 7 } 8 for(;l;l>>=1,r>>=1) sum[l]+=lc*val,sum[r]+=rc*val; 9 }
$lc$:當前左指針包含的區間長度。$rc$:當前右指針包含的區間長度。$len$:當前翻到的節點層管轄區間的長度。
這里我們將$l$,$r$都作為開區間。所以分別加$bit-1$,$bit+1$處理。
因為操作是自下而上進行的,$zkw$線段樹一般不維護懶標記,
因而我們用一個數組$add$進行標記永久化,表示這個區間的所有序列應該被加上這個值,顯然這個值是不能下傳的。
當$l$的最后一位為0,也就是說$l$指針為左兒子,那么l的右兄弟在當前修改的區間內。
同理$r$的左兄弟會在修改區間內。
當$l$,$r$兩個指針已經成為兄弟,也就是說二者在二進制下只有最后一位不同,即異或值為1,那么全部的修改操作已經完成,可以結束。
然而祖先鏈上的$sum$值仍然需要修改。
這里可以解釋,為什么$bit$應該大於$n+1$而不是$n$,為什么$bit+0$這個節點需要被空出來,因為我們需要開區間來進行操作。
然而似乎使$bit$僅保證大於$n$的打法是正確的,手玩確實沒有錯誤。
單點查詢
1 inline int query(int p){ 2 int ans=0; 3 for(p+=bit,ans=sum[p],p>>=1;p;p>>=1) ans+=add[p]; 4 return ans; 5 }
統計葉子節點的$sum$值,並不斷加上祖先鏈的$add$標記即可。
應當注意的是不要加上葉子節點的$add$標記,這個標記是無意義的。
區間查詢
1 inline int query(int l,int r){ 2 int ans=0,lc=0,rc=0,len=1; 3 for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){ 4 ans+=add[l]*lc+add[r]*rc; 5 if(~l&1) ans+=sum[l^1],lc+=len; 6 if(r&1) ans+=sum[r^1],rc+=len; 7 } 8 for(;l;l>>=1,r>>=1) ans+=add[l]*lc+add[r]*rc; 9 return ans; 10 }
區間查詢的打法是類似於區間修改的。
首先將$l$,$r$設為開區間。
不斷翻祖先鏈,記得加上兄弟節點整體的$sum$值和祖先鏈上部分的$add$標記就可以了。
應當注意的是循環中統計$add$標記和兄弟$sum$值的順序不可交換,否則可能導致$lc$,$rc$變量維護的含義錯誤。
區間最值
思想大概是將兒子的最值不斷差分到父親身上。
因為不同的部分已經被差分掉,可以直接修改區間的最值。
應當注意的是,區間最值的求法與區間求和不同。
為了減少特判,原本的開區間被轉化為閉區間。
但這樣產生一個問題,如果查詢區間長度為1會導致一些問題:左右端點永遠不會成為兄弟,故導致了死循環。
所以要加一個單點查詢的特判。
因為要維護區間最值,修改操作同時也要不斷差分,於是打的麻煩了許多,代碼可以參考下面。
基礎操作
1 #include<bits/stdc++.h> 2 using namespace std; 3 const int N=1e6+7; 4 inline int read(register int x=0,register char ch=getchar(),register int f=0){ 5 while(!isdigit(ch)) f=ch=='-',ch=getchar(); 6 while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); 7 return f?-x:x; 8 } 9 int n,m,bit; 10 int sum[N<<2],add[N<<2],mn[N<<2],mx[N<<2]; 11 inline void build(){ 12 for(bit=1;bit<=n+1;bit<<=1); 13 for(int i=bit+1;i<=bit+n;++i) mx[i]=mn[i]=sum[i]=read(); 14 for(int i=bit-1;i;--i){ 15 sum[i]=sum[i<<1]+sum[i<<1|1]; 16 mn[i]=min(mn[i<<1],mn[i<<1|1]); mn[i<<1]-=mn[i]; mn[i<<1|1]-=mn[i]; 17 mx[i]=max(mx[i<<1],mx[i<<1|1]); mx[i<<1]-=mx[i]; mx[i<<1|1]-=mx[i]; 18 } 19 } 20 inline int query(int p){ 21 int ans=0; 22 for(p+=bit,ans=sum[p],p>>=1;p;p>>=1) ans+=add[p]; 23 return ans; 24 } 25 inline int query(int l,int r){ 26 int ans=0,lc=0,rc=0,len=1; 27 for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){ 28 ans+=add[l]*lc+add[r]*rc; 29 if(~l&1) ans+=sum[l^1],lc+=len; 30 if(r&1) ans+=sum[r^1],rc+=len; 31 } 32 for(;l;l>>=1,r>>=1) ans+=add[l]*lc+add[r]*rc; 33 return ans; 34 } 35 inline int query_min(int l,int r){ 36 if(l==r) return query(l); 37 int lans=0,rans=0; 38 for(l+=bit,r+=bit;l^r^1;l>>=1,r>>=1){ 39 lans+=mn[l]; rans+=mn[r]; 40 if(~l&1) lans=min(lans,mn[l^1]); 41 if(r&1) rans=min(rans,mn[r^1]); 42 } 43 for(lans=min(lans+mn[l],rans+mn[r]),l>>=1;l;l>>=1) lans+=mn[l]; 44 return lans; 45 } 46 inline int query_max(int l,int r){ 47 if(l==r) return query(l); 48 int lans=0,rans=0; 49 for(l+=bit,r+=bit;l^r^1;l>>=1,r>>=1){ 50 lans+=mx[l]; rans+=mx[r]; 51 if(~l&1) lans=max(lans,mx[l^1]); 52 if(r&1) rans=max(rans,mx[r^1]); 53 } 54 for(lans=max(lans+mx[l],rans+mx[r]),l>>=1;l;l>>=1) lans+=mx[l]; 55 return lans; 56 } 57 inline void modify(int l,int r,int val){ 58 int lc=0,rc=0,len=1,x; 59 for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){ 60 sum[l]+=lc*val; sum[r]+=rc*val; 61 if(~l&1) sum[l^1]+=len*val,add[l^1]+=val,mn[l^1]+=val,lc+=len; 62 if(r&1) sum[r^1]+=len*val,add[r^1]+=val,mn[r^1]+=val,rc+=len; 63 x=min(mn[l],mn[l^1]); mn[l]-=x; mn[l^1]-=x; mn[l>>1]+=x; 64 x=min(mn[r],mn[r^1]); mn[r]-=x; mn[r^1]-=x; mn[r>>1]+=x; 65 } 66 for(;l;l>>=1,r>>=1){ 67 sum[l]+=lc*val; sum[r]+=rc*val; 68 x=min(mn[l],mn[l^1]); mn[l]-=x; mn[l^1]-=x; mn[l>>1]+=x; 69 x=max(mx[l],mx[l^1]); mx[l]-=x; mx[l^1]-=x; mx[l>>1]+=x; 70 } 71 } 72 inline void modify(int p,int val){ 73 int x; 74 for(p+=bit;p;p>>=1){ 75 sum[p]+=val; mn[p]+=val; mx[p]+=val; 76 x=min(mn[p],mn[p^1]); mn[p]-=x; mn[p^1]-=x; mn[p>>1]+=x; 77 x=max(mx[p],mx[p^1]); mx[p]-=x; mx[p^1]-=x; mx[p>>1]+=x; 78 } 79 } 80 int main(){ 81 n=read(); 82 build(); 83 return 0; 84 }
區間信息合並(山海經)
思想大概與區間查詢最值一致。
為了減少特判將開區間轉化為閉區間。
需要注意的是信息合並有左右的先后順序。
所以左右指針的寫法並不相同,最后將$l$,$r$掃過的信息合並就可以了。
1 #include<bits/stdc++.h> 2 const int N=100010; 3 inline int read(register int x=0,register char ch=getchar(),bool f=0){ 4 while(!isdigit(ch)) f=ch=='-',ch=getchar(); 5 while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); 6 return f?-x:x; 7 } 8 struct Ans{ 9 int l,r,val; 10 }; 11 struct Node{ 12 Ans la,ra,mx,tot; 13 }s[N<<2]; 14 int n,m,bit; 15 inline bool operator <(const Ans &a,const Ans &b){ 16 return a.val<b.val||(a.val==b.val&&a.l>b.l)||(a.val==b.val&&a.l==b.l&&a.r>b.r); 17 } 18 inline Ans operator +(const Ans &a,const Ans &b){ 19 return (Ans){a.l,b.r,a.val+b.val}; 20 } 21 inline Node operator +(const Node &a,const Node &b){ 22 return (Node){std::max(a.la,a.tot+b.la),std::max(b.ra,a.ra+b.tot),std::max(std::max(a.mx,b.mx),a.ra+b.la),a.tot+b.tot}; 23 } 24 void build(){ 25 for(bit=1;bit<=n+1;bit<<=1); 26 for(int i=bit+1;i<=bit+n;++i) s[i].tot=s[i].mx=s[i].la=s[i].ra=(Ans){i-bit,i-bit,read()}; 27 for(int i=bit-1;i;--i) s[i]=s[i<<1]+s[i<<1|1]; 28 } 29 Ans query(int r,int l){ 30 if(l==r) return s[l+bit].mx; 31 Node L=s[l+bit],R=s[r+bit]; 32 for(l+=bit,r+=bit;l^r^1;l>>=1,r>>=1){ 33 if(~l&1) L=L+s[l^1]; 34 if(r&1) R=s[r^1]+R; 35 } 36 return (L+R).mx; 37 } 38 void print(const Ans &x){ 39 printf("%d %d %d\n",x.l,x.r,x.val); 40 } 41 int main(){ 42 n=read(); m=read(); build(); 43 for(int i=1;i<=m;++i) print(query(read(),read())); 44 return 0; 45 }
