莫隊算法 (Mo's Algorithm)


省賽交了不熟莫隊的學費之后,決定寫篇博客復習一下。由於本人非常鄙視此類暴力算法(因為涉及分塊,感覺很不優美,而且我分塊姿勢也不熟練),於是一直沒有重視,省賽就被教育了……

比如GDCPC2019廣東省賽就有這么一道題:

  給定n,m,k,一個長度為n的數組,m次詢問。每次詢問給出區間[l,r],要求計算區間中有多少個a[i]、a[j]滿足i<j && abs(a[i]-a[j])<=k。n,m<=27000, k<=1e9。TL=1s,OL=256mb。

當時第一反應是線段樹,然而馬上感覺非常不可做,因為線段樹只能維護符合線性性的屬性(比如區間最值、區間和之類的);而這道題如果用線段樹維護,等於要把左右子區間的信息再一次合並,左右子區間的信息完全沒有貢獻(因為不符合線性性,不能通過線性操作來維護當前節點的信息),光是maintain的時間復雜度就已經達到了O(n^2*logn),更不要說建樹和查詢了,完全fake做法。

那應該怎么做呢?由於這道題沒有修改操作,所以可以考慮離線處理,最明顯的處理區間問題的離線算法就是莫隊了……莫隊本質上是一種分塊算法,通常用來離線解決只查詢不修改的區間問題。

讓我們看一道更水的題來學習莫隊算法:

  【例題】給定一個大小為N的數組,數組中所有元素的大小a[i]<=N。你需要回答M個查詢。每個查詢的形式是L,R。你需要回答在范圍[ L,R ]這個區間內數字的出現次數剛好是k的數字種數。k<=n,m<=30000,1<=L<r<=n。TL=1s,OL=256mb。

這個題能用線段樹做嗎?依然不行。原因還是我上面提到的問題:在維護segT[currentPosition]的信息時,不能簡單地通過segT[leftSon]和segT[rightSon]的信息來計算得出。每次維護segT[currentPosition]的信息時,都要重新處理左右子區間的信息來確定相互之間的影響,等於一棵fake線段樹。

既然不能線段樹,我們怎么思考這個問題呢?不妨假設我已經處理完了某個區間[l,r]的信息,現在我要處理區間[l,r+1]的信息,那我們就可以在O(1)時間內處理完畢:只需開一個cnt數組來記錄數字的出現次數,再cnt [ a [ r+1 ] ] ++即可。這樣,我們就可以知道區間[l±1,r±1]的信息,就可以用莫隊算法了。

莫隊算法的一種實現方式是:離線得到了一堆需要處理的區間后,合理地安排這些區間計算的次序以得到一個較優的復雜度。假設我們當前已知區間[l1,r1]的信息,要轉移到區間[l2,r2],由於l、r只能單步轉移,那么時間復雜度為O(abs(l1-l2)+abs(r1-r2))。要是把這兩個區間看作是平面上的兩個整點,就變成了曼哈頓距離。整體的時間復雜度必然為整棵樹的曼哈頓距離之和。顯然當整棵樹為MST時,時間復雜度最優。那么在求解答案時只要根據曼哈頓MST的dfs序求解即可。

還有另一種暴力寫法:先對序列分塊,然后以詢問左端點所在的分塊的序號為第一關鍵字,右端點的大小為第二關鍵字進行排序,按照排序好的順序計算,復雜度就會大大降低。

  1. 分塊相同時,右端點遞增是O(N)的,分塊共有O(\sqrt{N} )個,復雜度為O(N^{1.5} )
  2. 分塊轉移時,右端點最多變化N,分塊共有O(\sqrt{N} )個,復雜度為O(N^{1.5} )
  3. 分塊相同時,左端點最多變化\sqrt{N} ,分塊轉移時,左端點最多變化2\sqrt{N} ,共有N個詢問,復雜度為O(N^{1.5} )

故總時間復雜度O(N^{1.5} )

以例題為例。不妨給出一組樣例:n=9 (數組內容忽略,不重要) ,m=8。查詢區間為:[2,3], [1,4], [4,5], [1,6], [7,9], [8,9], [5,8], [6,8]。

因為最大范圍不超過n,所以我們以(int)sqrt(n)為大小,對區間進行分塊。在每一個塊中,我們按r從小到大的順序排列。所以上面的排序結果是:

{ (2,3) (1,4) (1,6),(4,5) (5,8) (6,8),(7,9) (8,9) }

這一步的排序可以這樣實現:

1 unit=(int)sqrt(n);
2 bool operator<(const Interval &rhs)const
3 {
4     if (l / unit != rhs.l / unit) return l / unit < rhs.l / unit;
5     else return r < rhs.r;
6 }

考慮到在同一個塊的時候,由於L的范圍確定,故每次L的偏移量是O(sqrt(n));但是R的范圍沒有確定,故R的偏移量是O(n)。

那么從一個塊到另一個塊呢? 顯然我們不用考慮R的偏移量,依然是O(n),而L明顯最多也是2*sqrt(n)。在這種情況下,很快就會到下下一塊。所以也是O(sqrt(n))。

由於有sqrt(n)個塊,所以R的總偏移量是O(n*sqrt(n)),而M個詢問,每個詢問都可以讓L偏移O(sqrt(n)),所以L的總偏移量O(m*sqrt(n))。

注意,R的偏移量和詢問數目沒有直接關系。而L則恰恰相反;L的偏移量我們剛才也說明了,和塊的個數沒有直接關系。所以總的時間復雜度是:O((n+m)*sqrt(n))。

還有道莫隊經典題也分享一下:bzoj2038

大意是詢問區間內任意選兩個數為同一個數的概率並化為最簡分數。

設在某一區間內共有顏色a1,a2,a3...an,每雙襪子的個數為b1,b2,b3...bn

答案為(\sum_{i=1}^{n}{b_{i}(b_{i}-1)/2} )/((R-L+1)(R-L)/2)

化簡(\sum_{i=1}^{n}{b_{i}^{2} }-b)/((R-L+1)(R-L)/2)

((\sum_{i=1}^{n}{b_{i}^{2} })-(R-L+1))/((R-L+1)(R-L)/2)

所以只需要用莫隊處理每個區間內不同數字的平方和就好了。

分塊寫法:

 1 //分塊
 2 #include <bits/stdc++.h>
 3 /* define */
 4 #define ll long long
 5 #define dou double
 6 #define pb emplace_back
 7 #define mp make_pair
 8 #define fir first
 9 #define sec second
10 #define sot(a,b) sort(a+1,a+1+b)
11 #define rep1(i,a,b) for(int i=a;i<=b;++i)
12 #define rep0(i,a,b) for(int i=a;i<b;++i)
13 #define repa(i,a) for(auto &i:a)
14 #define eps 1e-8
15 #define int_inf 0x3f3f3f3f
16 #define ll_inf 0x7f7f7f7f7f7f7f7f
17 #define lson curPos<<1
18 #define rson curPos<<1|1
19 /* namespace */
20 using namespace std;
21 /* header end */
22 
23 const int maxn = 5e4 + 10;
24 int n, m, unit, num[maxn], a[maxn];
25 struct Interval
26 {
27     int l, r, id;
28     bool operator<(const Interval &rhs)const
29     {
30         if (l / unit != rhs.l / unit) return l / unit < rhs.l / unit;
31         else return r < rhs.r;
32     }
33 } interval[maxn];
34 struct Ans
35 {
36     ll a, b;
37     void reduce()
38     {
39         ll g = __gcd(a, b);
40         a /= g, b /= g;
41     }
42 } ans[maxn];
43 
44 void solve()
45 {
46     ll tmp = 0;
47     rep0(i, 0, maxn) num[i] = 0;
48     int l = 1, r = 0; //初始區間
49     rep1(i, 1, m)
50     {
51         while (r < interval[i].r)
52         {
53             r++;
54             tmp -= (ll)num[a[r]] * num[a[r]];
55             num[a[r]]++;
56             tmp += (ll)num[a[r]] * num[a[r]];
57         }
58         while (r > interval[i].r)
59         {
60             tmp -= (ll)num[a[r]] * num[a[r]];
61             num[a[r]]--;
62             tmp += (ll)num[a[r]] * num[a[r]];
63             r--;
64         }
65         while (l < interval[i].l)
66         {
67             tmp -= (ll)num[a[l]] * num[a[l]];
68             num[a[l]]--;
69             tmp += (ll)num[a[l]] * num[a[l]];
70             l++;
71         }
72         while (l > interval[i].l)
73         {
74             l--;
75             tmp -= (ll)num[a[l]] * num[a[l]];
76             num[a[l]]++;
77             tmp += (ll)num[a[l]] * num[a[l]];
78         }
79         ans[interval[i].id].a = tmp - (r - l + 1);
80         ans[interval[i].id].b = (ll)(r - l + 1) * (r - l);
81         ans[interval[i].id].reduce();
82     }
83 }
84 
85 int main()
86 {
87     scanf("%d%d", &n, &m);
88     rep1(i, 1, n) scanf("%d", &a[i]);
89     rep1(i, 1, m)
90     {
91         interval[i].id = i; scanf("%d%d", &interval[i].l, &interval[i].r);
92     }
93     unit = (int)sqrt(n);
94     sot(interval, m);
95     solve();
96     rep1(i, 1, m) printf("%lld/%lld\n", ans[i].a, ans[i].b);
97     return 0;
98 }
View Code

曼哈頓MST寫法:

  1 #include <cstdio>
  2 #include <cstdlib>
  3 #include <algorithm>
  4 #define N 50000
  5 #define Q 50000
  6 #define INFI 123456789
  7 
  8 typedef long long ll;
  9 struct edge
 10 {
 11     int next, node;
 12 } e[Q << 1 | 1];
 13 int head[N + 1], tot = 0;
 14 struct point
 15 {
 16     int x, y, n;
 17     bool operator < (const point &p) const
 18     {
 19         return x == p.x ? y < p.y : x < p.x;
 20     }
 21 } interval[Q + 1], p[Q + 1];
 22 struct inedge
 23 {
 24     int a, b, w;
 25     bool operator < (const inedge &x) const
 26     {
 27         return w < x.w;
 28     }
 29 } ie[Q << 3 | 1];
 30 int cnt = 0;
 31 struct BITnode
 32 {
 33     int w, p;
 34 } arr[Q + 1];
 35 int n, q, col[N + 1], a[Q + 1], *l[Q + 1], f[N + 1], c[N + 1];
 36 ll cur, ans[Q + 1];
 37 bool v[Q + 1];
 38 
 39 template <typename T>
 40 inline T abs(T x)
 41 {
 42     return x < (T)0 ? -x : x;
 43 }
 44 
 45 inline int dist(const point &a, const point &b)
 46 {
 47     return abs(a.x - b.x) + abs(a.y - b.y);
 48 }
 49 
 50 inline void addinedge(int a, int b, int w)
 51 {
 52     ++cnt;
 53     ie[cnt].a = a, ie[cnt].b = b, ie[cnt].w = w;
 54 }
 55 
 56 inline void addedge(int a, int b)
 57 {
 58     e[++tot].next = head[a];
 59     head[a] = tot, e[tot].node = b;
 60 }
 61 
 62 inline bool cmp(int *a, int *b)
 63 {
 64     return *a < *b;
 65 }
 66 
 67 inline int query(int x)
 68 {
 69     int r = INFI, p = -1;
 70     for (; x <= q; x += x & -x)
 71         if (arr[x].w < r) r = arr[x].w, p = arr[x].p;
 72     return p;
 73 }
 74 
 75 inline void modify(int x, int w, int p)
 76 {
 77     for (; x > 0; x -= x & -x)
 78         if (arr[x].w > w) arr[x].w = w, arr[x].p = p;
 79 }
 80 
 81 int find(int x)
 82 {
 83     return x == f[x] ? x : f[x] = find(f[x]);
 84 }
 85 
 86 inline ll calc(int x)
 87 {
 88     return (ll)x * (x - 1);
 89 }
 90 
 91 inline void add(int l, int r)
 92 {
 93     for (int i = l; i <= r; ++i)
 94     {
 95         cur -= calc(c[col[i]]);
 96         cur += calc(++c[col[i]]);
 97     }
 98 }
 99 
100 inline void remove(int l, int r)
101 {
102     for (int i = l; i <= r; ++i)
103     {
104         cur -= calc(c[col[i]]);
105         cur += calc(--c[col[i]]);
106     }
107 }
108 
109 void dfs(int x, int l, int r)
110 {
111     v[x] = true;
112     //Process right bound
113     if (r < interval[x].y)
114         add(r + 1, interval[x].y);
115     else if (r > interval[x].y)
116         remove(interval[x].y + 1, r);
117     //Process left bound
118     if (l < interval[x].x)
119         remove(l, interval[x].x - 1);
120     else if (l > interval[x].x)
121         add(interval[x].x, l - 1);
122     ans[x] = cur;
123     //Moving on to next query
124     for (int i = head[x]; i; i = e[i].next)
125     {
126         if (v[e[i].node]) continue;
127         dfs(e[i].node, interval[x].x, interval[x].y);
128     }
129     //Revert changes
130     //Process right bound
131     if (r < interval[x].y)
132         remove(r + 1, interval[x].y);
133     else if (r > interval[x].y)
134         add(interval[x].y + 1, r);
135     //Process left bound
136     if (l < interval[x].x)
137         add(l, interval[x].x - 1);
138     else if (l > interval[x].x)
139         remove(interval[x].x, l - 1);
140 }
141 
142 int main()
143 {
144     //Initialize
145     scanf("%d%d", &n, &q);
146     for (int i = 1; i <= n; ++i) scanf("%d", col + i);
147     for (int i = 1; i <= q; ++i) scanf("%d%d", &interval[i].x, &interval[i].y);
148     //Manhattan MST
149     for (int i = 1; i <= q; ++i) p[i] = interval[i], p[i].n = i;
150     for (int dir = 1; dir <= 4; ++dir)
151     {
152         //Coordinate transform
153         if (dir == 2 || dir == 4)
154             for (int i = 1; i <= q; ++i) std::swap(p[i].x, p[i].y);
155         else if (dir == 3)
156             for (int i = 1; i <= q; ++i) p[i].x = -p[i].x;
157         //Sort by x-coordinate
158         std::sort(p + 1, p + q + 1);
159         //Discretize
160         for (int i = 1; i <= q; ++i) a[i] = p[i].y - p[i].x, l[i] = &a[i];
161         std::sort(l + 1, l + q + 1, cmp);
162         int cnt = 1;
163         for (int i = 2; i <= q; ++i)
164             if (*l[i] == *l[i - 1]) *l[i - 1] = cnt;
165             else *l[i - 1] = cnt++;
166         *l[q] = cnt;
167         //Initialize BIT
168         for (int i = 1; i <= q; ++i) arr[i].w = INFI, arr[i].p = -1;
169         //Find points and add edges
170         for (int i = q; i > 0; --i)
171         {
172             int pos = query(a[i]);
173             if (pos != -1)
174                 addinedge(p[i].n, p[pos].n, dist(p[i], p[pos]));
175             modify(a[i], p[i].x + p[i].y, i);
176         }
177     }
178     //Kruskal
179     std::sort(ie + 1, ie + cnt + 1);
180     //Initialize disjoint set
181     for (int i = 1; i <= q; ++i) f[i] = i;
182     //Add edges
183     for (int i = 1; i <= cnt; ++i)
184         if (find(ie[i].a) != find(ie[i].b))
185         {
186             f[find(ie[i].a)] = find(ie[i].b);
187             addedge(ie[i].a, ie[i].b), addedge(ie[i].b, ie[i].a);
188         }
189 
190     //Modui Algorithm
191     ++c[col[1]];
192     dfs(1, 1, 1);
193     //Output
194     for (int i = 1; i <= q; ++i)
195     {
196         ll x = ans[i], y = calc(interval[i].y - interval[i].x + 1);
197         if (!x) printf("0/1\n");
198         else
199         {
200             ll g = __gcd(x, y);
201             printf("%lld/%lld\n", x / g, y / g);
202         }
203     }
204     return 0;
205 }
View Code

 

一些相關的題目:

bzoj 3289, 3236, 3585, 2120, 1878

SPOJ D-query

hdu 6278

zoj 4008

poj 2104 (然而這題肯定主席樹更清晰)


 Reference:

莫隊算法 (Mo's Algorithm):  https://zhuanlan.zhihu.com/p/25017840 

曼哈頓距離最小生成樹與莫隊算法: https://blog.csdn.net/huzecong/article/details/8576908


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM