【HDU 4747 Mex】線段數


題目鏈接:http://acm.hdu.edu.cn/showproblem.php?pid=4747

 

題意:有一組序列a[i](1<=i<=N), 讓你求所有的mex(l,r), mex(l,r)表示區間[l,r]中最小的未在序列中出現的非負整數。

 

思路:冥思苦想半天無想法,白做了那么多線段樹。 很明顯的維護區間問題,容易想到線段樹,比較難想到操作。 枚舉一個序列的所mex(1,i),mex(2,i)……可以發現序列mex(x,i)是一個單調遞增序列,我們需要求得就是所有以x開頭的序列和,mex(x,i)(x<=i<=n)。這點確定了就好辦了,記錄每個位置的數后面最早重復出現的位置next[x],如果無則為設n+1。那么我們就可以發現,當第x個數所對應的序列 mex(x,i)(x<=i<=n)所對應的序列求完之后,刪去此位置的數,位置x+1~next[x]-1序列中mex值大於a[x]的都改為a[x],因為a[x]沒有了,下一個a[x]還未出現,所以可以證明這樣做是正確的。從1到n掃一遍亦求出了所有的mex()。

基本上所有的操作都可以用到線段樹。開始沒有想到一點的是如何找序列中剛好大於a[x]的位置,並且此位置到next[x]-1賦值為a[x],怎么都沒想到log(n)的操作,其實這里依然可以用到線段樹,因為序列是單調遞增的,另開一個區間維護序列mavv[u]表示區間中最大的mex值,隨着詢問以及其他操作成段更新即可。

 

  1 #include <iostream>
  2 #include <cstdio>
  3 #include <cmath>
  4 #include <map>
  5 #include <algorithm>
  6 #include <cstring>
  7 #include <sstream>
  8 using namespace std;
  9 
 10 #define lz 2*u,l,mid
 11 #define rz 2*u+1,mid+1,r
 12 typedef long long lld;
 13 const int maxn=222222;
 14 int a[maxn], b[maxn], next[maxn];
 15 lld sum[4*maxn], mavv[4*maxn], flag[4*maxn];
 16 map<int,int>mp;
 17 
 18 void push_up(int u, int l, int r)
 19 {
 20     sum[u]=sum[2*u]+sum[2*u+1];
 21     mavv[u]=mavv[2*u+1];
 22 }
 23 
 24 void push_down(int u, int l, int r)
 25 {
 26     int mid=(l+r)>>1;
 27     if(flag[u]!=-1)
 28     {
 29         flag[2*u]=flag[2*u+1]=flag[u];
 30         mavv[2*u]=mavv[2*u+1]=flag[u];
 31         sum[2*u]=(lld)(mid-l+1)*flag[u];
 32         sum[2*u+1]=(lld)(r-mid)*flag[u];
 33         flag[u]=-1;
 34     }
 35 }
 36 
 37 void build(int u, int l, int r)
 38 {
 39     flag[u]=-1;
 40     int mid=(l+r)>>1;
 41     if(l==r)
 42     {
 43         sum[u]=mavv[u]=b[l];
 44         return ;
 45     }
 46     build(lz);
 47     build(rz);
 48     push_up(u,l,r);
 49 }
 50 
 51 void Update(int u, int l, int r, int tl, int tr, int val)
 52 {
 53     if(tl>tr) return ;
 54     if(tl<=l&&r<=tr)
 55     {
 56         mavv[u]=val;
 57         sum[u]=(lld)val*(r-l+1);
 58         flag[u]=val;
 59         return ;
 60     }
 61     push_down(u,l,r);
 62     int mid=(l+r)>>1;
 63     if(tr<=mid) Update(lz,tl,tr,val);
 64     else if(tl>mid) Update(rz,tl,tr,val);
 65     else
 66     {
 67         Update(lz,tl,mid,val);
 68         Update(rz,mid+1,tr,val);
 69     }
 70     push_up(u,l,r);
 71 }
 72 
 73 int find(int u, int l, int r, int tmp)
 74 {
 75     if(l==r) return l;
 76     push_down(u,l,r);
 77     int mid=(l+r)>>1;
 78     if(mavv[2*u]>tmp) return find(lz,tmp);
 79     else return find(rz,tmp);
 80 }
 81 
 82 int main()
 83 {
 84     int n;
 85     while(cin >> n,n)
 86     {
 87         for(int i=1; i<=n; i++) scanf("%d",a+i);
 88         mp.clear();
 89         for(int i=n; i>=1; i--)
 90         {
 91             if(mp[ a[i] ]) next[i]=mp[ a[i] ];
 92             else next[i]=n+1;
 93             mp[ a[i] ]=i;
 94         }
 95         mp.clear();
 96         int x=0;
 97         for(int i=1; i<=n; i++)
 98         {
 99             mp[ a[i] ]=1;
100             while(mp[x]) ++x;
101             b[i]=x;
102         }
103         build(1,1,n);
104         lld ans=0;
105         for(int i=1; i<=n; i++)
106         {
107             ans+=sum[1];
108             if(mavv[1]>a[i])
109             {
110                 int id=find(1,1,n,a[i]);
111                 Update(1,1,n,max(id,i+1),next[i]-1,a[i]);
112             }
113             Update(1,1,n,i,i,0);
114         }
115         cout << ans <<endl;
116     }
117 }
View Code

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM