樹狀數組


樹狀數組

一、適用范圍

  • 樹狀數組是一個查詢和修改復雜度都為 \(log(n)\) 的數據結構,常常用於查詢任意區間的所有元素之和。
  • 與前綴和的區別是支持動態修改, \(log(n)\) 的時間進行修改,\(log(n)\) 查詢。
  • 支持如下操作:
    • 單點修改區間查詢
    • 區間修改單點查詢
    • 區間修改區間查詢

二、算法原理

  1. 樹狀數組較好的利用了二進制。它的每個節點的值代表的是自己前面一些連續元素。至於到底是前面哪些元素,這就由這個節點的下標決定。

  1. 設節點的編號為 \(i\) ,那么:

\[c[i]=\sum_{j=i-lowbit(i)+1}^i a[j] \]

  1. 即可以推導出:

    C[1] = A[1]  # lowbit(1)個元素之和
    C[2] = C[1] + A[2] = A[1] + A[2]  # lowbit(2)個元素之和
    C[3] = A[3]  # lowbit(3)個元素之和
    C[4] = C[2] + C[3] +A[4] = A[1] + A[2] + A[3] + A[4] # lowbit(4)個元素之和
    C[5] = A[5]
    C[6] = C[5] + A[6] = A[5] + A[6]
    C[7] = A[7]
    C[8] = C[4] + C[6] + C[7] + A[8] = A[1] + A[2] + A[3] + A[4] + A[5] + A[6] + A[7] + A[8]
    
  2. 顯然一個節點並不一定是代表自己前面所有元素的和。只有滿足 \(2^n\) 這樣的數才代表自己前面所有元素的和。

  3. 理解 \(lowbit\) 函數

    • 原碼:如果機器字長為 \(n\),那么一個數的原碼就是用一個 \(n\) 位的二進制數,其中最高位為符號位:正數為 \(0\),負數為 \(1\)。剩下的 \(n-1\) 位表示該數的絕對值。

    • 反碼:知道了原碼,那么你只需要具備區分 \(0\)\(1\) 的能力就可以輕松求出反碼,為什么呢?因為反碼就是在原碼的基礎上,符號位不變其他位按位取反(就是 \(0\)\(1\)\(1\)\(0\))就可以了。

    • 補碼也非常的簡單,就是在反碼的基礎上按照正常的加法運算加 \(1\) 。正數的補碼就是其本身。負數的補碼是在其原碼的基礎上符號位不變,其余各位取反,最后 \(+1\),即取反 \(+1\)

    • $lowbit(x)=x&-x $ :表示截取 \(x\) 二進制最右邊的 \(1\) 所表示的值,可以寫成函數或宏定義

    • 注意宏定義是括號,因為宏名只是起到一個替代作用,不加括號在運算時優先級會出問題

      //1. 宏定義,注意括號,不建議這樣寫,容易產生歧義
      #define lowbit(x) ((x) & -(x))
      //2. 函數寫法,推薦寫法:
      int lowbit(int x){return x & -x;}
      

三、 樹狀數組的操作

  1. \(update\) 更新操作

    • 因為樹狀數組 \(c[x]\) 維護的是一個或若干個連續數之和,當我們修改了 \(a[x]\) 之后,\(x\sim n\) 前綴和均發生了變化,所以除了\(c[x]\) 需要修改之外 \(x\) 的祖先節點也必須修改而 \(x\) 的父親節點為 \(x+lowbit(x)\),我們叫向上更新。

    • 把序列中第 \(i\) 個數增加 \(x\)\(sum[i]\sim sum[n]\) 均增加了 \(x\) ,所以我們只需把這個增量往上更新即可。如果,把 \(a[i]\) 修改成 \(x\),則我們向上更新 \(a[i]\) 的增量:\(x-a[i]\)

      //1. a[id] 增加 x while寫法
      void updata(int id,int x){
          while(id<=n){//向上更新,更新到n為止
              c[id]+=x;
              id+=lowbit(id);
          }
      }
      //2. a[id] 修改成 x  for寫法
      void updata(int id,int x){//或者傳遞參數是x=x-a[id],此時跟第一種寫法一樣
          for(int i=id;i<=n;i+=lowbit(i))
              c[i]+=x-a[id];
      }
      
  2. \(getsum\) 查詢操作

    • 因為樹狀數組維護的是一個能夠動態修改的前綴和,所以可以在 \(log(n)\) 的效率下求出前 \(n\) 項和\(sum[i]\)

    • 如果 \(i=2^j (j=0,1,..n)\), 此時最簡單,顯然有:\(sum[i]=c[i]\) ,如果 \(i\) 是其他的情況呢?

      • \(sum[5]=c[5]+c[4]\ (4=5-lowbit(5))\)
      • \(sum[15]=c[15]+c[14]+c[12]+c[8]\ (14=15-lowbit(15),12=14-lowbit(14),...)\)
    • 顯然,想要求出前 \(i\) 項前綴和 \(sum[i]\) ,只需沿着當前節點向下累加直到節點編號為 \(2^j\) 為止。我們叫向下求和。

      int getsum(int id){
          int tot=0;
          for(int i=id;i>0;i-=lowbit(i))
              tot+=c[i];
          return tot;
      }
      

四、求逆序對

  • 算法思想

    • 逆序對就是如果 \(i > j\ \&\&\ a[i] < a[j]\),這兩個就算一對逆序對。其實也就是對於每個數而言,找找排在其前面有多少個比自己大的數。
    • 我們用數組 \(c[i]\) 記錄在數 \(i\) 之前出現的在 \([i-lowbit[i],i]\) 的數的個數。
    • 所以我們只需要向下更新向上求和來求出逆序對的個數了。
    • 注意,我們維護的是序列數的數值的大小,所以序列元素值 $a[i]>0 $ ,且元素大小不宜太大,而且必須為整數。
  • $Code $

    #include <bits/stdc++.h>
    const int maxn=1e6+5;
    int n,ans,a[maxn],c[maxn];
    int lowbit(int x){return x & -x;}
    void modify(int i){
        for(;i;i-=lowbit(i)) c[i]+=1;
    }
    int getsum(int i){
        int tot=0;
        for(;i<=maxn;i+=lowbit(i)) tot+=c[i];
        return tot;
    }
    void Solve(){
        scanf("%d",&n);
        for(int i=1;i<=n;++i){
            scanf("%d",&a[i]);
            a[i]++; //避免a[i]-1=0
            ans+=getsum(a[i]-1);
            modify(a[i]);
        }
        printf("%d\n",ans);
    }
    int main(){
        Solve();
        return 0;
    }
    
  • 離散化版 \(Code\)

    #include <bits/stdc++.h>
    const int maxn=1e5+5;
    int a[maxn],b[maxn],c[maxn];
    int n,cnt;
    int lowbit(int x){return x & -x;}
    void updata(int i){
        for(;i;i-=lowbit(i)) c[i]+=1;
    }
    int getsum(int i){
        int tot=0;
        for(;i<=n;i+=lowbit(i)) tot+=c[i];
        return tot;
    }
    void Solve(){
        scanf("%d",&n);
        srand(time(0));
        for(int i=1;i<=n;++i){
            a[i]=rand()%n;
            b[i]=a[i];
            printf("%d ",a[i]);
        }
        printf("\n");
        std::sort(b+1,b+n+1);
        cnt=std::unique(b+1,b+n+1)-b;
        for(int i=1;i<=n;++i) a[i]=std::lower_bound(b+1,b+cnt,a[i])-b;
        int ans=0;
        for(int i=1;i<=n;++i){
            ans+=getsum(a[i]+1);
            updata(a[i]);
        }
        printf("%d\n",ans);
    }
    int main(){
        Solve();
        return 0;
    }
    

五、離散化

  1. 什么是離散化呢?

    • 很多時候,我們並不關心數組中每個值的大小,只關心它們的序的關系。
      • 在求數組的逆序對的時候,9 8 7 6 55 4 3 2 1 具有相同的逆序對
      • 我們只關心數組的每個數右邊有多少個比當前元素小的數,至於每個數有多大並不重要。
    • 通常我們把 一個具有 n 個 unique values 的數組映射到 range [1, n]的整數的操作叫做離散化。
    • 如果數組有重復元素,重復元素在離散化后的數組也需要具有相同的值。
  2. 離散化的兩種方法:

    • 方法一:lower_bound

      • 對原始數據進行備份,並對備份數組進行排序。

      • \(stl\)\(unique\) 函數對排序后的數組進行去重。

      • 二分查找原始數組里每個元素在去重后的備份數組中的位置,並把位置作為數組的新的值。

      • \(Code\)

        #include <bits/stdc++.h>
        const int maxn=1e5+5;
        int a[maxn],b[maxn];//a為原數組,b為備份數組
        int n,cnt;
        void Solve(){
            scanf("%d",&n);
            srand(time(0));
            for(int i=1;i<=n;++i){
                a[i]=rand()%(2*n);
                b[i]=a[i];
            }
            std::sort(b+1,b+n+1);//備份數組排序
           	cnt=std::unique(b+1,b+n+1)-b-1;//備份數組排序,cnt指向不重的最后一個元素
            for(int i=1;i<=n;++i) //二分查找a[i]在數組中的位置,並用相對大小代替原始值。
                a[i]=std::lower_bound(b+1,b+cnt+1,a[i])-b;   
        }
        int main(){
            Solve();
            return 0;
        }
        
      • unique 解析:

        • unique 函數的函數原型如下:

          iterator unique(iterator it_1,iterator it_2);
          
        • 這兩個參數表示對容器中 \([it\_1,it\_2)\) 范圍的元素進行去重,注意區間是前閉后開

        • 返回值是一個迭代器,它指向的是去重后容器中不重復序列的最后一個元素的下一個元素

        • unique 函數的去重過程實際上就是不停的把后面不重復的元素移到前面來,也可以說是用不重復的元素占領重復元素的位置

        • unique 函數實現過程等價於下面函數:

          iterator My_Unique (iterator first, iterator last){
              if (first==last) return last; 
              iterator result = first;//result指向最后一個不重復的最后一個元素
              while (++first != last){//遍歷整個序列
                  if (!(*result == *first)) //first和result指向的值不相等
                      *(++result)=*first;//把first指向的值移動到result的下一個位置
              }//如果first和result指向值相等,first往后遍歷。
              return ++result;//把不重復的最后一個元素的下一個位置的迭代器返回。
          }
          
        • unique 函數去重一般需要對序列進行排序,否則有可能不能真正的去重。

    • 方法二:排序之后,枚舉着放回原數組

      • 結構體存下原數和位置。

      • 對結構體數組按照元素的值進行排序

      • 枚舉排序后的數組,\(rank[id]=i\) 離散化數組。

      • \(Code\)

        #include <bits/stdc++.h>
        const int maxn=1e5;
        struct Node{
            int id,data;
            bool operator <(const Node &a)const{
                return data<a.data;
            }
        }a[maxn];
        int n,rank[maxn];
        void Solve(){
            scanf("%d",&n);
            srand(time(0));
            for(int i=1;i<=n;++i){
                a[i].id=i;
                a[i].data=rand()%n;
            }
            std::sort(a+1,a+n+1);
            for(int i=1;i<=n;++i) rank[a[i].id]=i;
            for(int i=1;i<=n;++i) printf("%d ",rank[i]);
        }
        int main(){
            Solve();
            return 0;
        }
        
      • 這種離散化方式沒有對相同元素去重,如果需要去重也比較麻煩,一般情況下用第一種方法進行離散化,簡單好寫還不容易出錯。

六、區間修改單點查詢

  1. 差分思想

    • 對一個 \(n\) 個元素的序列 \(\{a_1,a_2,...,a_n \}\) ,令 \(b_i=a_i-a_{i-1}\) ,產生新的序列 \(\{b_1,b_2,...,b_n\}\) ,我們稱 序列 \(b\) 為序列 \(a\) 的差分數組。
      • 序列 \(a=\{1,8,10,7,10\}\),則其差分序列 \(b=\{1,7,2,-3,3\}\)
      • 為了方便計算,序列編號一般為\(1\sim n\) ,且默認 \(a_0=0\)
    • 根據差分的定義,\(b_1=a_1-a_0,b_2=a_2-a_1,...,b_n=a_n-a_{n-1}\) ,由此我們很容易得出:\(a_i=\sum_{j=1}^{i} b_j\)
  2. 區間修改單點查詢

    • 如果我們用樹狀數組維護原序列的差分序列,我們很容易通過向上更新,向下求和的方式求出原序列的每一個元素。
    • 如果我們對原序列的 \([l,r]\) 區間的每一個元素增加 \(x\) ,此時我們只需對樹狀數組 \(c[l]\) 向上更新 \(x\) ,這樣向下查詢每一個元素的新的值的時候區間 \([l,n]\) 之間的元素值都增加了 \(x\) ,為了消除對區間 \([r+1,n]\) 之間的元素的影響,我們只需對樹狀數組 \(c[r+1]\) 處向上更新一個 \(-x\) 即可。
  3. 代碼實現:

    #include <bits/stdc++.h>
    const int maxn=1e6+5;
    typedef long long ll;
    ll a[maxn],b[maxn],c[maxn];
    int n;
    int lowbit(int x){return x & -x;}
    void updata(int i,ll x){
        for(;i<=n;i+=lowbit(i)) c[i]+=x;
    }
    ll getsum(int i){
        ll tot=0;
        for(;i;i-=lowbit(i)) tot+=c[i];
        return tot;
    }
    void Solve(){
        int Q;
        scanf("%d%d",&n,&Q);
        for(int i=1;i<=n;++i){
            scanf("%lld",&a[i]);
            b[i]=a[i]-a[i-1];//差分數組
            updata(i,b[i]);
        }
        int l,r;
        ll x;
        while(Q--){
            int flag;scanf("%d",&flag);
            if(flag==1){
                scanf("%d%d%lld",&l,&r,&x);
                updata(l,x);
                updata(r+1,-x);
            }
            else{
                int X;
                scanf("%d",&X);//查詢a[X]。
                printf("%lld\n",getsum(X));
            }
        }
    }
    int main(){
        Solve();
        return 0;
    }
    

七、區間修改區間查詢

  • 樹狀數組的區間查詢也是在差分的基礎上進行的,有上面的差分可知:

    • \(a_i=\sum_{j=1}^i b_j\)

    • 前綴和:\(sum_i=\sum_{j=1}^i a_i=\sum_{j=1}^i\sum_{k=1}^j b_k\)

    • \[\begin{aligned} sum_i&=a_1+a_2+...+a_i\\ &=b_1+(b_1+b_2)+...+(b_1+b_2+..+b_i)\\ &=i*b_1+(i-1)*b_2+...+2*b_{i-1}+b_i\\ &=i*(b_1+b_2+...+b_i)-(0*b_1+1*b_2+...+(i-1)*b_i)\\ &=i*\sum_{j=1}^i b_j-(0*b_1+1*b_2+...+(i-1)*b_i) \end{aligned} \]

    • 所以我們只需用一個樹狀數組維護 \(b_i\) ,一個樹狀數組維護 \((i-1)*b_i\) 即可。

    • \(Code\)

      #include <bits/stdc++.h>
      const int maxn=1e6+5;
      typedef long long ll;
      ll a[maxn],c1[maxn],c2[maxn];
      int n;
      int lowbit(int x){return x & -x;}
      void updata(int x,ll w){
          for(int i=x;i<=n;i+=lowbit(i)) {
              c1[i]+=w;//維護差分數組
              c2[i]+=(x-1)*w;//維護(i-1)*bi
          }
      }
      ll getsum(int x){
          ll tot=0;
          for(int i=x;i;i-=lowbit(i)) tot+=x*c1[i]-c2[i];
          return tot;
      }
      void Solve(){
          int Q;
          scanf("%d%d",&n,&Q);
          for(int i=1;i<=n;++i){
              scanf("%lld",&a[i]);
              updata(i,a[i]-a[i-1]);
          }
          int l,r;
          ll x;
          while(Q--){
              int flag;scanf("%d",&flag);
              if(flag==1){
                  scanf("%d%d%lld",&l,&r,&x);
                  updata(l,x);
                  updata(r+1,-x);
              }
              else{
                  scanf("%d%d",&l,&r);
                  printf("%lld\n",getsum(r)-getsum(l-1));
              }
          }
      }
      int main(){
          Solve();
          return 0;
      }
      


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM