什么是主席樹
可持久化數據結構(Persistent data structure)就是利用函數式編程的思想使其支持詢問歷史版本、同時充分利用它們之間的共同數據來減少時間和空間消耗。
因此可持久化線段樹也叫函數式線段樹又叫主席樹。
可持久化數據結構
在算法執行的過程中,會發現在更新一個動態集合時,需要維護其過去的版本。這樣的集合稱為是可持久的。
實現持久集合的一種方法時每當該集合被修改時,就將其整個的復制下來,但是這種方法會降低執行速度並占用過多的空間。
考慮一個持久集合S。
如圖所示,對集合的每一個版本維護一個單獨的根,在修改數據時,只復制樹的一部分。
稱之為可持久化數據結構。
可持久化線段樹
令 T 表示一個結點,它的左兒子是 left(T),右兒子是 right(T)。
若 T 的范圍是 [L,R],那么 left(T) 的范圍是 [L,mid],right(T) 的范圍是 [mid+1,R]。
單點更新
我們要修改一個葉子結點的值,並且不能影響舊版本的結構。
在從根結點遞歸向下尋找目標結點時,將路徑上經過的結點都復制一份。
找到目標結點后,我們新建一個新的葉子結點,使它的值為修改后的版本,並將它的地址返回。
對於一個非葉子結點,它至多只有一個子結點會被修改,那么我們對將要被修改的子結點調用修改函數,那么就得到了它修改后的兒子。
在每一步都向上返回當前結點的地址,使父結點能夠接收到修改后的子結點。
在這個過程中,只有對新建的結點的操作,沒有對舊版本的數據進行修改。
區間查詢
從要查詢的版本的根節點開始,像查詢普通的線段樹那樣查詢即可。
延遲標記
...
區間第K小值問題
有n個數,多次詢問一個區間[L,R]中第k小的值是多少。
查詢[1,n]中的第K小值
我們先對數據進行離散化,然后按值域建立線段樹,線段樹中維護某個值域中的元素個數。
在線段樹的每個結點上用cnt記錄這一個值域中的元素個數。
那么要尋找第K小值,從根結點開始處理,若左兒子中表示的元素個數大於等於K,那么我們遞歸的處理左兒子,尋找左兒子中第K小的數;
若左兒子中的元素個數小於K,那么第K小的數在右兒子中,我們尋找右兒子中第K-(左兒子中的元素數)小的數。
查詢區間[L,R]中的第K小值
我們按照從1到n的順序依次將數據插入可持久化的線段樹中,將會得到n+1個版本的線段樹(包括初始化的版本),將其編號為0~n。
可以發現所有版本的線段樹都擁有相同的結構,它們同一個位置上的結點的含義都相同。
考慮第i個版本的線段樹的結點P,P中儲存的值表示[1,i]這個區間中,P結點的值域中所含的元素個數;
假設我們知道了[1,R]區間中P結點的值域中所含的元素個數,也知道[1,L-1]區間中P結點的值域中所包含的元素個數,顯然用第一個個數減去第二個個數,就可以得到[L,R]區間中的元素個數。
因此我們對於一個查詢[L,R],同步考慮兩個根root[L-1]與root[R],用它們同一個位置的結點的差值就表示了區間[L,R]中的元素個數,利用這個性質,從兩個根節點,向左右兒子中遞歸的查找第K小數即可。
POJ 2104 K-th Number (HDU 2665)
注意可持久化數據結構的內存開銷非常大,因此要注意盡可能的減少不必要的空間開支。

1 const int maxn=100001; 2 struct Node{ 3 int ls,rs; 4 int cnt; 5 }tr[maxn*20]; 6 int cur,rt[maxn]; 7 void init(){ 8 cur=0; 9 } 10 inline void push_up(int o){ 11 tr[o].cnt=tr[tr[o].ls].cnt+tr[tr[o].rs].cnt; 12 } 13 int build(int l,int r){ 14 int k=cur++; 15 if (l==r) { 16 tr[k].cnt=0; 17 return k; 18 } 19 int mid=(l+r)>>1; 20 tr[k].ls=build(l,mid); 21 tr[k].rs=build(mid+1,r); 22 push_up(k); 23 return k; 24 } 25 int update(int o,int l,int r,int pos,int val){ 26 int k=cur++; 27 tr[k]=tr[o]; 28 if (l==pos&&r==pos){ 29 tr[k].cnt+=val; 30 return k; 31 } 32 int mid=(l+r)>>1; 33 if (pos<=mid) tr[k].ls=update(tr[o].ls,l,mid,pos,val); 34 else tr[k].rs=update(tr[o].rs,mid+1,r,pos,val); 35 push_up(k); 36 return k; 37 } 38 int query(int l,int r,int o,int v,int kth){ 39 if (l==r) return l; 40 int mid=(l+r)>>1; 41 int res=tr[tr[v].ls].cnt-tr[tr[o].ls].cnt; 42 if (kth<=res) return query(l,mid,tr[o].ls,tr[v].ls,kth); 43 else return query(mid+1,r,tr[o].rs,tr[v].rs,kth-res); 44 }
常數優化的技巧
一種在常數上減小內存消耗的方法:
插入值時候先不要一次新建到底,能留住就留住,等到需要訪問子節點時候再建下去。
這樣理論內存復雜度依然是O(Nlg^2N),但因為實際上很多結點在查詢時候根本沒用到,所以內存能少用一些。
動態第K小值
每一棵線段樹是維護每一個序列前綴的值在任意區間的個數,如果還是按照靜態的來做的話,那么每一次修改都要遍歷O(n)棵樹,時間就是O(2*M*nlogn)->TLE。
考慮到前綴和,我們通過樹狀數組來優化,即樹狀數組套主席樹,每個節點都對應一棵主席樹,那么修改操作就只要修改logn棵樹,O(nlognlogn+Mlognlogn)時間是可以的,但是直接建樹要nlogn*logn(10^7)會MLE。
我們發現對於靜態的建樹我們只要nlogn個節點就可以了,而且對於修改操作,只是修改M次,每次改變倆個值(減去原先的,加上現在的)也就是說如果把所有初值都插入到樹狀數組里是不值得的,所以我們分兩部分來做,所有初值按照靜態來建,內存O(nlogn),而修改部分保存在樹狀數組中,每次修改logn棵樹,每次插入增加logn個節點O(M*logn*logn+nlogn)。
可用主席樹解決的問題
POJ 2104 K-th Number
入門題,求區間第K小數。

1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 #include <algorithm> 5 using namespace std; 6 const int maxn=100001; 7 struct Node{ 8 int ls,rs; 9 int cnt; 10 }tr[maxn*20]; 11 int cur,rt[maxn]; 12 void init(){ 13 cur=0; 14 } 15 inline void push_up(int o){ 16 tr[o].cnt=tr[tr[o].ls].cnt+tr[tr[o].rs].cnt; 17 } 18 int build(int l,int r){ 19 int k=cur++; 20 if (l==r) { 21 tr[k].cnt=0; 22 return k; 23 } 24 int mid=(l+r)>>1; 25 tr[k].ls=build(l,mid); 26 tr[k].rs=build(mid+1,r); 27 push_up(k); 28 return k; 29 } 30 int update(int o,int l,int r,int pos,int val){ 31 int k=cur++; 32 tr[k]=tr[o]; 33 if (l==pos&&r==pos){ 34 tr[k].cnt+=val; 35 return k; 36 } 37 int mid=(l+r)>>1; 38 if (pos<=mid) tr[k].ls=update(tr[o].ls,l,mid,pos,val); 39 else tr[k].rs=update(tr[o].rs,mid+1,r,pos,val); 40 push_up(k); 41 return k; 42 } 43 int query(int l,int r,int o,int v,int kth){ 44 if (l==r) return l; 45 int mid=(l+r)>>1; 46 int res=tr[tr[v].ls].cnt-tr[tr[o].ls].cnt; 47 if (kth<=res) return query(l,mid,tr[o].ls,tr[v].ls,kth); 48 else return query(mid+1,r,tr[o].rs,tr[v].rs,kth-res); 49 } 50 int b[maxn]; 51 int sortb[maxn]; 52 int main() 53 { 54 int n,m; 55 int T; 56 //scanf("%d",&T); 57 //while (T--){ 58 while (~scanf("%d%d",&n,&m)){ 59 init(); 60 for (int i=1;i<=n;i++){ 61 scanf("%d",&b[i]); 62 sortb[i]=b[i]; 63 } 64 sort(sortb+1,sortb+1+n); 65 int cnt=1; 66 for (int i=2;i<=n;i++){ 67 if (sortb[i]!=sortb[cnt]){ 68 sortb[++cnt]=sortb[i]; 69 } 70 } 71 rt[0]=build(1,cnt); 72 for (int i=1;i<=n;i++){ 73 int p=lower_bound(sortb+1,sortb+cnt+1,b[i])-sortb; 74 rt[i]=update(rt[i-1],1,cnt,p,1); 75 } 76 for (int i=0;i<m;i++){ 77 int a,b,k; 78 scanf("%d%d%d",&a,&b,&k); 79 int idx=query(1,cnt,rt[a-1],rt[b],k); 80 printf("%d\n",sortb[idx]); 81 } 82 } 83 return 0; 84 }
SPOJ 3267 D-query
求區間內不重復的數的個數。
掃描數列建立可持久化線段樹,第i個數若第一次出現,則在線段樹中的位置i加1;若不是第一次出現,將上次出現的位置減1,在本次位置加1。
對於每個詢問的區間 [L,R],在第R個版本上的線段樹只有前R個數,在線段樹上查詢位置L,對經過的區間中的和進行累計即可。

1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 #include <algorithm> 5 #include <map> 6 using namespace std; 7 const int maxn=100001; 8 struct Node{ 9 int ls,rs; 10 int cnt; 11 }tr[maxn*20]; 12 int cur,rt[maxn]; 13 void init(){ 14 cur=0; 15 } 16 inline void push_up(int o){ 17 tr[o].cnt=tr[tr[o].ls].cnt+tr[tr[o].rs].cnt; 18 } 19 int build(int l,int r){ 20 int k=cur++; 21 if (l==r) { 22 tr[k].cnt=0; 23 return k; 24 } 25 int mid=(l+r)>>1; 26 tr[k].ls=build(l,mid); 27 tr[k].rs=build(mid+1,r); 28 push_up(k); 29 return k; 30 } 31 int update(int o,int l,int r,int pos,int val){ 32 int k=cur++; 33 tr[k]=tr[o]; 34 if (l==pos&&r==pos){ 35 tr[k].cnt+=val; 36 return k; 37 } 38 int mid=(l+r)>>1; 39 if (pos<=mid) tr[k].ls=update(tr[o].ls,l,mid,pos,val); 40 else tr[k].rs=update(tr[o].rs,mid+1,r,pos,val); 41 push_up(k); 42 return k; 43 } 44 int query(int l,int r,int o,int pos){ 45 if (l==r) return tr[o].cnt; 46 int mid=(l+r)>>1; 47 if (pos<=mid) return tr[tr[o].rs].cnt+query(l,mid,tr[o].ls,pos); 48 else return query(mid+1,r,tr[o].rs,pos); 49 } 50 int b[maxn]; 51 map<int,int> mp; 52 int main() 53 { 54 int n,m; 55 //int T; 56 //scanf("%d",&T); 57 //while (T--){ 58 while (~scanf("%d",&n)){ 59 mp.clear(); 60 init(); 61 for (int i=1;i<=n;i++){ 62 scanf("%d",&b[i]); 63 } 64 rt[0]=build(1,n); 65 for (int i=1;i<=n;i++){ 66 if (mp.find(b[i])==mp.end()){ 67 mp[b[i]]=i; 68 rt[i]=update(rt[i-1],1,n,i,1); 69 } 70 else{ 71 int tmp=update(rt[i-1],1,n,mp[b[i]],-1); 72 rt[i]=update(tmp,1,n,i,1); 73 } 74 mp[b[i]]=i; 75 } 76 scanf("%d",&m); 77 for (int i=0;i<m;i++){ 78 int a,b; 79 scanf("%d%d",&a,&b); 80 int ans=query(1,n,rt[b],a); 81 printf("%d\n",ans); 82 } 83 } 84 return 0; 85 }