標題:遞增三元組
給定三個整數數組
A = [A1, A2, ... AN],
B = [B1, B2, ... BN],
C = [C1, C2, ... CN],
請你統計有多少個三元組(i, j, k) 滿足:
- 1 <= i, j, k <= N
- Ai < Bj < Ck
【輸入格式】
第一行包含一個整數N。
第二行包含N個整數A1, A2, ... AN。
第三行包含N個整數B1, B2, ... BN。
第四行包含N個整數C1, C2, ... CN。
對於30%的數據,1 <= N <= 100
對於60%的數據,1 <= N <= 1000
對於100%的數據,1 <= N <= 100000 0 <= Ai, Bi, Ci <= 100000
【輸出格式】
一個整數表示答案
【樣例輸入】
3
1 1 1
2 2 2
3 3 3
【樣例輸出】
27
資源約定:
峰值內存消耗(含虛擬機) < 256M
CPU消耗 < 1000ms
請嚴格按要求輸出,不要畫蛇添足地打印類似:“請您輸入...” 的多余內容。
注意:
main函數需要返回0;
只使用ANSI C/ANSI C++ 標准;
不要調用依賴於編譯環境或操作系統的特殊函數。
所有依賴的函數必須明確地在源文件中 #include
不能通過工程設置而省略常用頭文件。
提交程序時,注意選擇所期望的語言類型和編譯器類型。
思路:
遞增三元組,用三個數組表示;先排序sort,找到a數組中 第一個大於等於b[i]的數下標為j,找到c數組中 第一個比b[i]大的數下標為k,推導公式計算出結果。
AC代碼:
#include<iostream>
#include<algorithm>
using namespace std;
//4
//1 3 4 5
//1 2 2 2
//2 3 3 4
int a[100010];
int b[100010];
int c[100010];
int n;
long long ans = 0;
int main(){
//輸入數據
cin>>n;
for(int i=1;i<=n;i++){
cin>>a[i];
}
for(int i=1;i<=n;i++){
cin>>b[i];
}
for(int i=1;i<=n;i++){
cin>>c[i];
}
//排序
sort(a+1,a+n+1);
sort(b+1,b+n+1);
sort(c+1,c+n+1);
//定義兩個指針(下標)
int j = 1;
int k = 1;
//以b為中間值 在a數組 c數組中查找
for(int i=1;i<=n;i++){
while(j<=n && a[j] < b[i]) j++; //在a數組中查找第一個大於等於b[i]的數
while(k<=n && c[k] <= b[i]) k++; //在c數組中查找第一個大於b[i]的數
ans += (long long)(j-1) * (n-k+1); //計算公式 可以自己舉例推導出來
}
cout<<ans<<endl;
}
upper_bound 和 lower_bound也可以解決問題:
代碼
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<cstdlib>
using namespace std;
typedef long long ll;
const int maxn=1e5+10;
int a[maxn],b[maxn],c[maxn];
int main(){
int n;
cin>>n;
for(int i=0; i<n; i++ ){
cin>>a[i];
}
for(int i=0; i<n; i++ ){
cin>>b[i];
}
for(int i=0; i<n; i++ ){
cin>>c[i];
}
sort(a,a+n);
sort(b,b+n);
sort(c,c+n);
ll cnt=0;
//以b為中間值
for(int i=0; i<n; i++ ){
ll pos1 = lower_bound(a,a+n,b[i])-a; //在a數組中查找比b大於等於的第一個數的指針
ll pos2 = upper_bound(c,c+n,b[i])-c; //在c數組中查找比b大的最第一個數的指針
cnt += (ll)pos1*(n-pos2);
}
cout<<cnt<<endl;
return 0;
}
下面是錯誤代碼,不能以a為基准 在b、c中查找。會出現b和c中元素重復情況。解決方法是 以b數組為基准 在a、c數組中查找
代碼:
#include<iostream>
#include<algorithm>
using namespace std;
//4
//1 3 4 5
//1 2 2 2
//2 3 3 4
int a[100010];
int b[100010];
int c[100010];
int n;
long long ans = 0;
int main(){
//輸入數據
cin>>n;
for(int i=1;i<=n;i++){
cin>>a[i];
}
for(int i=1;i<=n;i++){
cin>>b[i];
}
for(int i=1;i<=n;i++){
cin>>c[i];
}
//排序
sort(a+1,a+n+1);
sort(b+1,b+n+1);
sort(c+1,c+n+1);
//定義兩個指針(下標)
int j = 1;
int k = 1;
for(int i=1;i<=n;i++){
//找到b數組中 第一個比a[i]大的數
while(j<=n && b[j]<a[i]){
j++;
}
// cout<<"j="<<j;
//找到c數組中 第一個比a[i]大的數
while(k<=n && c[k]<a[i]){
k++;
}
// cout<<" k="<<k<<endl;
if(j<=n && k<=n){
ans += (n-j+1) * (n-k+1);//計算公式:b中j后面的數都比a[i]大 k同理 組數=jk相乘
}
}
cout<<ans<<endl;
}