樹狀數組從前往后求和,用來解第k大(或小)的數 poj 2985 The k-th Largest Group


 

 

來自http://www.cnblogs.com/oa414/archive/2011/07/21/2113234.html的啟發,

看上述博客如何求第k大的數時,被其第二份代碼影響,感覺很巧妙,於是研究了一下,搞懂后頓時神清氣爽啊。。。

還是看這張經典的圖吧,知識在圖上就變得形象多了

 現在假設要求sum[a]的值,一般我們都是從后往前求和,如a=15

15-lowbit(15)=14;

14-lowbit(14)=12;

12-lowbit(12)=8;

8-lowbit(b)=0;

答案就是sum[15]+sum[14]+sum[12]+sum[8];

現在我們可以這樣來求,從不超過15的只有一個1的最大二進制數開始,也可以理解為指數從log(15)取整開始,即3,2的3次等於8,依次加上2的2次,2的1次,2的0次,數字依次為8,12,14,15,也就是把普通的求和過程反向。

好了,方向求和有什么好處呢?

在求第k大的數的時候就派上用場了,雖然還有很多其他方法可以解決第k大的數,但樹狀數組無疑是最優雅的方法了

下面就以poj 2418這一題來簡單說一下怎么求第k大的數

由於樹狀數組記錄的是比當前元素小的數的個數,所以可以先把求第k大的數轉換為求第num-k+1小的數,num是總的數的個數

int find_kth(int k)//太神奇了(大概是以前沒有完全領會),log(n)復雜度
{

int ans = 0, cnt = 0, i;
for (i = 20; i >= 0; i--)//利用二進制的思想,把答案用一個二進制數來表示
{

ans += (1 << i);
if (ans >= maxn|| cnt + c[ans] >= k)
//這里大於等於k的原因是可能會有很多個數都滿足cnt + c[ans] >= k,所以找的是最大的滿足cnt+c[ans]<k的ans
ans -= (1 << i);

else
cnt += c[ans];//cnt用來累加比當前ans小的總組數
}//求出的ans是累加和(即小於等於ans的數的個數)小於k的情況下ans的最大值,所以ans+1就是第k大的數
return ans + 1;

}


完整代碼即詳細注釋

 

 

View Code
#include<stdio.h>
#include<string.h>
#define maxn 300000
int a[maxn],c[maxn],p[maxn];//值為i的數有i個
int find(int x){return x==p[x] ? x : p[x]=find(p[x]);}
int lowbit(int x){
return x&-x;
}
void update(int x,int d){
for(;x<=maxn;x+=lowbit(x))
c[x]+=d;
}//因為是從左往右手動求和了,所以也不需要sum()操作了
int find_kth(int k)//太神奇了(大概是以前沒有完全領會),log(n)復雜度
{
int ans = 0, cnt = 0, i;
for (i = 20; i >= 0; i--)//利用二進制的思想,把答案用一個二進制數來表示
{
ans += (1 << i);
if (ans >= maxn|| cnt + c[ans] >= k)
//這里大於等於k的原因是可能大小為ans的數不在c[ans]的控制范圍之內,所以這里求的是 < k
ans -= (1 << i);
else
cnt += c[ans];//cnt用來累加比當前ans小的總組數
}//求出的ans是累加和(即小於等於ans的數的個數)小於k的情況下ans的最大值,所以ans+1就是第k大的數
return ans + 1;
}
/*
因為要求第k小的數,所以要從左往右加過去,
上述過程其實就是把樹狀數組的求和操作逆向,從左往右求和,
邊求和邊判斷控制范圍內比當前值要小的數是否超過或等於k,如果是則跳回兄弟節點(ans-=(1<<i))
如8+4=12,假如12不滿足要求,則重新變回8,下一次就加2,8+2=10,即跳到10控制的位置
上述累加過程不會重復計算,因為
比如15=8+4+2+1,數字依次為8 12 14 15,每次累加后的值都與前面的值無關,i小於其二進制末尾0的個數
即c[8] 、c[12] 、c[14]、 c[15]相加的話不會重復計算,再如11=8+2+1;數字依次為8 10 11,c[8],c[10],c[11]
各自控制着自己的范圍,不會重復累加,所以就可以用cnt來累加前面的結果,最后cnt+c[ans]就表示了值<=ans的個數
簡言之:上述的各個數字兩兩間控制的范圍不會相交
*/
int main()
{
int i,n,m,q,x,y,k,l,r;
scanf("%d%d",&n,&m);
for(i=1;i<=n;i++) p[i]=i;
for(i=1;i<=n;i++) a[i]=1;
update(1,n);//初始狀態值為1的數有n個
int num=n;
for(i=1;i<=m;i++)
{
scanf("%d",&q);
if(q==0)
{
scanf("%d%d",&x,&y);
x=find(x);
y=find(y);
if(x==y) continue;
update(a[x],-1);
update(a[y],-1);
update(a[x]+a[y],1);
p[y]=x;
a[x]+=a[y];
num--;//合並集合
}
else
{
scanf("%d",&k);
k=num-k+1;//轉換為找第k小的數
printf("%d\n",find_kth(k));
}
}
return 0;
}

 



二分做法

View Code
#include<stdio.h>
#include<string.h>
#define maxn 300000
int a[maxn],c[maxn],p[maxn];//值為i的數有i個
int find(int x){return x==p[x] ? x : p[x]=find(p[x]);}
int lowbit(int x){
return x&-x;
}
void update(int x,int d){
for(;x<=maxn;x+=lowbit(x))
c[x]+=d;
}
int sum(int x){
int ans=0;
for(;x>0;x-=lowbit(x))
ans+=c[x];
return ans;
}
int main()
{
int i,n,m,q,x,y,k,l,r;
scanf("%d%d",&n,&m);
for(i=1;i<=n;i++) p[i]=i;
for(i=1;i<=n;i++) a[i]=1;
update(1,n);//初始狀態值為1的數有n個
int num=n;
for(i=1;i<=m;i++)
{
scanf("%d",&q);
if(q==0)
{
scanf("%d%d",&x,&y);
x=find(x);
y=find(y);
if(x==y) continue;
update(a[x],-1);
update(a[y],-1);
update(a[x]+a[y],1);
p[y]=x;
a[x]+=a[y];
num--;//合並集合
}
else
{
scanf("%d",&k);
k=num-k+1;//轉換為找第k小的數
l=1;
r=n;
while(l <= r)
{
int mid=(l+r)>>1;
if(sum(mid) >= k) r=mid-1;//盡量往左逼近
else l=mid+1;
}
printf("%d\n",l);
}
}
return 0;
}





好像還可以用平衡樹,線段樹等來做,改天再補上

treap寫法:比樹狀數組還快

View Code
#include<cstdio>
#include<set>
#include<cstdlib>
#include<cstring>
using namespace std;
const int maxn = 300010;
#define L ch[rt][0]
#define R ch[rt][1]
int ch[maxn][2], aux[maxn] , num[maxn] , size[maxn] , cnt[maxn];
int val[maxn];
int tot,rt;
inline void init(){
    size[0]=0;
    rt = tot = 0;
    ch[0][0] = ch[0][1] = 0;
    aux[0] = 0;
}
inline void pushup(int rt){
    size[rt]=cnt[rt]+size[L]+size[R];
}
inline void Rotate(int &rt,int f){//f=1:右旋  f=0:左旋
    int t = ch[rt][!f];
    ch[rt][!f] = ch[t][f];
    ch[t][f] = rt;
    pushup(rt);pushup(t);
    rt = t;
}
void insert(int &rt,int key){
    if(!rt) {
        rt = ++tot;
        val[rt] = key; L = R = 0; size[rt]=cnt[rt]=1;
        aux[rt] = ( rand() << 14 ) + rand();
        return ;
    }
    if(key==val[rt]) {
        ++cnt[rt];
    }else if(key < val[rt]){
        insert(L , key);
        if( aux[L] < aux[rt] ) Rotate(rt,1);
    }else {
        insert(R , key);
        if( aux[R] < aux[rt] ) Rotate(rt,0);
    }
    pushup(rt);
}
void treap_delete(int &rt){//real deletion
    if(!L || !R){
         rt=L?L:R;  
    }else {
        if(aux[L] < aux[R]){
            Rotate(rt,1);
            treap_delete(R);
        }else {
            Rotate(rt,0);
            treap_delete(L);
        }
    }
}
void del(int &rt , int key){//lazy deletion 
    if(key == val[rt]) {
        cnt[rt]--;
        size[rt]--;
        if(cnt[rt]==0) 
            treap_delete(rt);
    }
    else {
        if(key < val[rt])
            del(L,key);
        else     
            del(R,key);
        size[rt]--;
    }
}
int find(int rt,int key){
    if(!rt) return 0;
    else if(key < val[rt])  return find(L,key);
    else if(key > val[rt])  return find(R,key);
    else return cnt[rt];
}
//找后繼結點
void succ(int rt,int key,int &ans){//找>=key的第一個結點,即后繼結點
    if(!rt)  return ;
    if(val[rt] >= key){
        ans=val[rt];
        succ(L,key,ans);
    }else
        succ(R,key,ans);
}
//找前驅結點
void pre(int rt,int key,int &ans){
    if(!rt) return ;
    if(val[rt]<=key) {
        ans=val[rt];
        succ(R,key,ans);
    }else 
        succ(L,key,ans);
}
int getmin(int rt){
    while(L) rt=L;    return val[rt];
}
int getmax(int rt){
    while(R) rt=R;   return val[rt];
}
//找第k小的數
int find_kth(int rt,int k){
    if(k<size[L]+1)
        return find_kth(L,k);
    else if(k>size[L]+cnt[rt])
        return find_kth(R,k-size[L]-cnt[rt]);
    else return val[rt];
}
//確定key的排名
int treap_rank(int rt,int key,int cur){//cur:當前已知比要求元素(key)小的數的個數
    if(key == val[rt])  
        return size[L] + cur + 1;
    else if(key < val[rt])
        treap_rank(L,key,cur);
    else 
        treap_rank(R,key,cur+size[L]+cnt[rt]);
}
int a[maxn],c[maxn],p[maxn];//值為i的數有i個
int find(int x){return x==p[x] ? x : p[x]=find(p[x]);}
int main()
{
    init();
    int i,n,m,q,x,y,k,l,r;
    scanf("%d%d",&n,&m);
    for(i=1;i<=n;i++) p[i]=i;
    for(i=1;i<=n;i++) a[i]=1;
    for(int i=1;i<=n;i++) insert(rt,1);
    int num=n;
    for(i=1;i<=m;i++)
    {
        scanf("%d",&q);
        if(q==0)
        {
            scanf("%d%d",&x,&y);
            x=find(x);
            y=find(y);
            if(x==y) continue;
            del(rt,a[x]);
            del(rt,a[y]);
            insert(rt,a[x]+a[y]);
            p[y]=x;
            a[x]+=a[y];
            num--;//合並集合
        }
        else 
        {
            scanf("%d",&k);
            k=num-k+1;//轉換為找第k小的數
            printf("%d\n",find_kth(rt,k));
        }
    }
    return 0;
}

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM