統計逆序對的兩種解法
歸並排序(mergeSort)
逆序對定義
\(i<j\) 但\(a[i]>a[j]\),假設我們分別使得通過mergeSort使得左右半邊有序
即\(a[1]...a[mid]\) 遞增, \(a[mid+1]....a[n]\)遞增,我們需要通過merge操作,完成整個的排序和新增逆序對的計數,較小值出現在左半邊記為 a[i],出現在右半邊即為 a[j],那么每次出現在右半邊,意味左半邊比a[i]大的數都比a[j]大,由此可以統計逆序對
HDU1394
代碼實現
#include<bits/stdc++.h>
using namespace std;
#define db(x) cout<<"["<<#x<<"]="<<x<<endl
/*
歸並排序求逆序對+規律
*/
const int maxn = 5010;
int a[maxn];
int c[maxn];
int b[maxn];
//mergeSort
int n;
int merge1(int* a,int l1,int r1,int l2,int r2){
int p1=l1,p2 = l2;
int t = 0;
int cnt = 0;
//db(a[p1]);db(a[p2]);
while(p1<=r1&&p2<=r2){
if(a[p1]<a[p2]){
b[t] = a[p1];
p1++;
t++;
}
else{//a[p1]>a[p2]; a[p2] 小於 p1...r1所有數
b[t] = a[p2];
cnt+=(r1-p1+1);
//db(cnt);
p2++;
t++;
}
}
while(p1<=r1){b[t]=a[p1];p1++,t++;}
while(p2<=r2){b[t]=a[p2];p2++,t++;}
for(int k=0;k<t;k++){
a[l1+k] = b[k];
}
//db(cnt);db(l1);db(r1);db(l2);db(r2);
return cnt;
}
int mergeSort(int* a,int l,int r){
if(l==r) return 0;
int cnt = 0;
int mid = (l+r)>>1;
cnt+=mergeSort(a,l,mid);
cnt+=mergeSort(a,mid+1,r);
cnt+=merge1(a,l,mid,mid+1,r);
return cnt;
}
int main(){
while(cin>>n){
for(int i=0;i<n;i++){cin>>a[i];c[i]=a[i];}
int tmp = mergeSort(a,0,n-1);
//db(tmp);
int mint = tmp;
for(int i=0;i<n-1;i++){
tmp +=n-1-2*c[i];
//db(tmp);
mint = min(tmp,mint);
}
cout<<mint<<endl;
}
return 0;
}
線段樹
線段樹的解法非常簡單,每次插入a[i] ,同時對a[i]+1....n-1進行計數;
此時要求元素范圍不能太大,當然如果是在\(1..n\)之間,那么非常理想
代碼實現
#include<bits/stdc++.h>
using namespace std;
#define db(x) cout<<"["<<#x<<"]="<<x<<endl
const int maxn = 5e3+10;
struct node{
int l,r,num; //num維護的信息是節點插入的區間插入節點的數目
}tr[maxn<<2];//線段樹
int a[maxn];
void build(int n,int x,int y){//n是根節點下標,x,y是維護的區間范圍
tr[n].l = x,tr[n].r = y;
tr[n].num = 0;
if(x==y) return ;
int mid = (x+y)>>1;//no over
build(n<<1,x,mid);
build(n<<1|1,mid+1,y);
tr[n].num = tr[n<<1].num+tr[n<<1|1].num;
}
void modify(int n,int p){//跟新區間單點p的信息
int l = tr[n].l, r= tr[n].r;
if(l==r&&l==p){//found
tr[n].num=1;
return ;
}
int mid = (l+r)>>1;
if(p<=mid) modify(n<<1,p);
if(p>mid) modify(n<<1|1,p);
tr[n].num = tr[n<<1].num+tr[n<<1|1].num;
}
int query(int n,int x,int y){
int l = tr[n].l , r=tr[n].r;
int mid = (l+r)>>1;
int ans = 0;
if(l>=x&&r<=y){//x,y覆蓋了l,r
return tr[n].num;
}
if(x<=mid) ans+=query(n<<1,x,y);
if(y>mid) ans+=query(n<<1|1,x,y);
return ans;
}
int n;
int main(){
while(cin>>n){
build(1,0,n-1);
int ans = 0;
for(int i=0;i<n;i++){
cin>>a[i];
int t=query(1,a[i]+1,n-1);
//db(t);
ans+=t;
modify(1,a[i]);
}
int mint = ans;
//db(ans);
for(int i=0;i<n-1;i++){
ans+=(n-1-2*a[i]);
mint = min(ans,mint);
}
cout<<mint<<endl;
}
return 0;
}