分析:
這是一張完全圖,並且邊的權值是由點的權值$xor$得到的,所以我們考慮貪心的思想,考慮$kruskal$的過程選取最小的邊把兩個連通塊合並,所以我們可以模仿$kruskal$的過程,倒着做$kruskal$,設定當前的最高位為$d$,我們把點集分為兩個集合,$s$集合代表$d$位為$1$的點,$t$集合代表$d$位為$0$的點,就是$st$兩個連通塊,考慮這兩個連通塊的連接,把$t$連通塊建出一棵$trie$樹,然后枚舉$s$集合中的點,去查找最小邊,然后統計最小邊的數量,遞歸解決$st$兩個連通塊,最后統計方案數的時候就是乘法原理...
為什么按照每一位的$01$來划分集合?我們考慮現在把$s$拆成兩個連通塊,這樣一共有三個連通塊,如果按照貪心的思想,一定是先連接$s$的連通塊,因為最高位一定是$0$,這樣邊比較小...
需要注意的細節就是如果有很多相同的點,並且這張子圖是完全圖,那么這就是一個完全圖生成樹計數的問題,根據$prufer$可以得出點數為$n$的完全圖生成樹計數為$n^{n-2}$...證明請見:http://www.matrix67.com/blog/archives/682
代碼:
#include<algorithm> #include<iostream> #include<cstring> #include<cstdio> //by NeighThorn #define pa pair<int,int> #define inf 0x3f3f3f3f #define mp make_pair using namespace std; const int maxn=100000+5,mod=1e9+7; int n,tot,anscnt,a[maxn],s[maxn],t[maxn],fac[maxn]; long long sum; struct Trie{ int cnt,nxt[2]; }tr[maxn*30]; inline int read(void){ char ch=getchar();int x=0; while(!(ch>='0'&&ch<='9')) ch=getchar(); while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar(); return x; } inline void init(void){ for(int i=0;i<=tot;i++) tr[i].nxt[0]=tr[i].nxt[1]=tr[i].cnt=0; tot=0; } inline void insert(int x){ int p=0; for(int i=30,y;i>=0;i--){ y=(x>>i)&1; if(!tr[p].nxt[y]) tr[p].nxt[y]=++tot; p=tr[p].nxt[y]; } tr[p].cnt++; } inline pa find(int x){ int p=0,ans=0; for(int i=30,y;i>=0;i--){ y=(x>>i)&1; if(tr[p].nxt[y]) p=tr[p].nxt[y],ans|=y<<i; else p=tr[p].nxt[y^1],ans|=(y^1)<<i; } return mp(ans^x,tr[p].cnt); } inline int power(int x,int y){ int res=1; while(y){ if(y&1) res=1LL*res*x%mod; x=1LL*x*x%mod,y>>=1; } return res; } inline void solve(int l,int r,int dep){ if(l>=r) return; if(dep<0){ if(r-l+1>=2) anscnt=1LL*anscnt*power(r-l+1,r-l-1)%mod; return; } int cnt1=0,cnt2=0; for(int i=l;i<=r;i++) if((a[i]>>dep)&1) s[cnt1++]=a[i]; else t[cnt2++]=a[i]; for(int i=0;i<cnt1;i++) a[l+i]=s[i]; for(int i=0;i<cnt2;i++) a[l+cnt1+i]=t[i]; init();pa tmp;int ans=inf,cnt=0; for(int i=0;i<cnt2;i++) insert(t[i]); for(int i=0;i<cnt1;i++){ tmp=find(s[i]); if(tmp.first<ans) ans=tmp.first,cnt=tmp.second; else if(tmp.first==ans) cnt+=tmp.second; } if(sum!=inf&&cnt) sum+=ans,anscnt=1LL*cnt*anscnt%mod; solve(l,l+cnt1-1,dep-1);solve(l+cnt1,r,dep-1); } signed main(void){ n=read(),anscnt=1;fac[0]=1; for(int i=1;i<=n;i++) fac[i]=1LL*fac[i-1]*i%mod; for(int i=1;i<=n;i++) a[i]=read(); solve(1,n,30); printf("%lld\n%d\n",sum,anscnt); return 0; }
By NeighThorn