給定整數m以及n個數字A1,A2,...An,將數列A中所有元素兩兩異或,共能得到n(n-1)/2個結果,請求出這些結果中大於m的有多少個。
一看題目,感覺是trie樹,也沒搞清楚邏輯,就開始碼代碼,這樣的結果注定是失敗的!正確的做法是在紙上畫清楚,每一步應該怎么做,應該怎么考慮邊界條件,怎么搜索,拿最簡單的例子測試,然后自己再想一些邊界例子測試,最后才是碼代碼,除非你對這題很熟,或者感覺是水題,直接寫也可以。反正最后不會做,看了一下大神代碼,半天才理解,都是套路,其實這個性質一時半會分析不出來吧!
分析:暴力是不行的,1e5的數據范圍,暴力肯定超時,大方向肯定是用trie樹進行壓縮,然后是查找,關鍵的問題是,給定數字a,我們需要尋找b的格式,使得a^b>m,怎么從trie樹中尋找b的個數,就是解決這道題目的關鍵。
1. 我想說明下數據范圍,n,m,Ai都是[1,1e5]的,(1 << 17)>1e5,所以一個數至少要17位來存儲,所以trie樹的節點個數就是1e5*17,這個不理解的話,仔細查看一下trie樹的資料吧。
2. 然后是a^b>m,現在我們知道a和m,要查找b的個數,首先a和m可以簡單的表示成17位二進制01的形式,然后查找。查找的時候,以m為導向,我們盡量確保a^b以后,如果m相應位置為0,a^b相應位置為1的,肯定比m大(肯定大需要按照這里要求的方式從高位到低位進行枚舉),然后最后結果加上這些為1的即可!不管后面的位是什么情況,因為結果肯定是大於m的! 如果m相應的位置為1,我們需要a和b相應的位置不同,即一個為1,一個為0.到這里可能迷糊了,如果相同也能保證異或結果大於m啊,比如(m=01010,a=00110,這里簡單起見,一共5位,從左到右編號,m的第一位為0,如果a^b以后這位為1,顯然是滿足的,最后結果直接加上即可,不管后面的位是什么情況;然后考慮m的第二位為1,這里我需要保證a^b結果,相應位為1,然后往后才能找到滿足要求異或結果大於m的,當然,有人說結果是0也可以啊,但是前一種情況已經把這種情況考慮進去了,這里只需要考慮這里的結果位為1即可,也就是a和b相應的二進制的這一位不同)
3. 第二點有點難懂,下面可以結合代碼理解。我這里還要說明,什么時候統計結果,統計結果就是a和b的異或結果大於m,也就是m相應位置為1,結果就是異或結果相應位置為1,也就是異或結果盡量保證m二進制位上為1的位置盡量為1,(盡量為1的意思是:不一定一定為1,考慮上面的給的例子的第一種情況)。然后為0的位置至少有一個位置為1.最后的結果要用long long來存儲,n(n-1)/2,int可能會溢出。最后,(a,b)和(b,a)算一種情況,所以結果需要除以2.
友情提示:跟1異或,相當於該位取反。a^b=c,有a^c=b.
下面是牛客網上抄的,比我講的清楚。
異或那道題可以把每個數的二進制位求出來,用一個字典樹維護,然后遍歷每一個數按位貪心,比如這一位m是1,遍歷的這個數這一位是0,那么和他異或的數就必須是1,如果這一位m是0,要大於m的話異或和的這一位可以是1也可以是零,ans加上之前維護的二進制位加上使這一位為1的數在字典樹中查詢有多少個數滿足這個前綴的條件,然后在令這一位的異或和為0,繼續向下遍歷,最后的答案除以2.
貼上大神的代碼,膜拜一下!orz。
放幾個我認為有幫助的題目吧:
1. http://codeforces.com/problemset/problem/282/E 這個相關度最高,trie+xor
2. http://codeforces.com/contest/706/problem/D 這個也是trie+xor
3. http://codeforces.com/contest/714/problem/C 這個只有trie,這個題有點意思,還有不用trie樹的簡單做法。
4. https://threads-iiith.quora.com/Tutorial-on-Trie-and-example-problems 關於trie樹的一點知識點吧!
1 #include <cstdio> 2 #include <cstring> 3 4 const int N = 100010; 5 6 int a[N]; 7 8 struct node { 9 int count; 10 int next[2]; 11 }p[N*17], root; 12 13 int cnt = 0; 14 void insert(int *a, int len) { 15 int now = 0; 16 for (int i = 0; i < len; ++i) { 17 if (p[now].next[a[i]] == -1) { 18 cnt++; 19 p[cnt].next[0] = p[cnt].next[1] = -1; 20 p[cnt].count = 0; 21 p[now].next[a[i]] = cnt; 22 } 23 now = p[now].next[a[i]]; 24 p[now].count++; 25 } 26 } 27 28 typedef long long LL; 29 int query(int *a, int *b, int len) { 30 int now = 0; 31 int ret = 0; 32 for (int i = 0; now != -1 && i < len; ++i) { 33 if (b[i] == 0) { 34 if (p[now].next[1^a[i]] != -1) ret += p[p[now].next[1^a[i]]].count; 35 now = p[now].next[a[i]]; 36 } 37 else { 38 now = p[now].next[1^a[i]]; 39 } 40 } 41 return ret; 42 } 43 44 45 int main() { 46 int n, m; 47 while (scanf("%d%d", &n, &m) == 2) { 48 cnt = 0; 49 p[0].next[0] = p[0].next[1] = -1; 50 p[0].count = 0; 51 for (int i = 0; i < n; ++i) { 52 scanf("%d", &a[i]); 53 int tmp[18]; 54 for (int j = 0; j < 18; ++j) 55 tmp[j] = (a[i] >> (17 - j)) & 1; 56 insert(tmp, 18); 57 //for (int i = 0; i < 30; ++i) printf("%d ", p[i].count); 58 //puts("----"); 59 } 60 int kk[18]; 61 for (int j = 0; j < 18; ++j) 62 kk[j] = (m >> (17 - j)) & 1; 63 LL ret = 0; 64 for (int i = 0; i < n; ++i) { 65 int tmp[18]; 66 for (int j = 0; j < 18; ++j) 67 tmp[j] = (a[i] >> (17 - j)) & 1; 68 ret += query(tmp, kk, 18); 69 //printf("%d\n", ret); 70 } 71 printf("%lld\n", ret / 2); 72 } 73 return 0; 74 }
之前不小心加上了gist的鏈接,所以打開很慢!
摘自牛客網:
剛剛聽到另外一個方法...建好樹之后,把m轉成二進制,如果m當前位是0,直接把經過左右節點的數的個數相乘;如果m當前位是1,就分別從左右節點的子節點里選一個分支進行組合,遞歸調用,結果相加......
dfs里面的第一個if條件的原因,如果左右節點相等,i>j的情況不考慮,原因是:對於每一位,我們考慮(00,01,10,11),當當前節點相同的時候,01和10只需要計算一次即可,仔細想想!
1 /* 2 ID: y1197771 3 PROG: test 4 LANG: C++ 5 */ 6 #include<bits/stdc++.h> 7 #define pb push_back 8 #define FOR(i, n) for (int i = 0; i < (int)n; ++i) 9 #define dbg(x) cout << #x << " at line " << __LINE__ << " is: " << x << endl 10 typedef long long ll; 11 using namespace std; 12 typedef pair<int, int> pii; 13 const int maxn = 1e5 + 10; 14 struct node { 15 int next[2]; 16 int c; 17 node() { 18 memset(next, 0, sizeof next); 19 c = 0; 20 } 21 } A[maxn * 18]; 22 int num; 23 int n, m; 24 int tag[18]; 25 void insert(int x) { 26 int u = 0, cur = 0; 27 for (int i = 17; i >= 0; i--) { 28 cur = ((1 << i) & x) > 0; 29 if(!A[u].next[cur]) { 30 A[u].next[cur] = ++num; 31 } 32 u = A[u].next[cur]; 33 //cout << u << " " << x << endl; 34 A[u].c++; 35 } 36 } 37 ll dfs(int cur, int l, int r) { 38 39 if(cur < 0) return 0; 40 ll res = 0; 41 for (int i = 0; i <= 1; i++) { 42 for (int j = 0; j <= 1; j++) { 43 if(l == r && i > j) continue; 44 if(!A[A[l].next[i] ].c || !A[A[r].next[j] ].c) continue; 45 if((i ^ j) > tag[cur]) { 46 res += 1ll * A[A[l].next[i] ].c * A[A[r].next[j] ].c; 47 } else if((i ^ j) == tag[cur]) { 48 res += dfs(cur - 1, A[l].next[i], A[r].next[j]); 49 } 50 } 51 } 52 //cout << cur << " " << l << " " << r << " " << res << endl; 53 return res; 54 } 55 void solve() { 56 num = 0; 57 memset(A, 0, sizeof A); 58 cin >> n >> m; 59 int x; 60 for (int i = 0; i < n; i++) { 61 cin >> x; 62 insert(x); 63 } 64 for (int i = 17; i >= 0; i--) { 65 tag[i] = ((1 << i) & m) > 0; 66 } 67 printf("%I64d\n", dfs(17, 0, 0)); 68 } 69 int main() { 70 freopen("data", "r", stdin); 71 //freopen("test.out", "w", stdout); 72 int _ = 20; 73 while(_--) 74 solve(); 75 return 0; 76 }
