题目:
给定三个整数数组 A = [A1, A2, … AN], B = [B1, B2, … BN], C = [C1, C2,
… CN], 请你统计有多少个三元组(i, j, k) 满足:
i1 <= i, j, k <= N
Ai < Bj < Ck
【输入格式】 第一行包含一个整数N。 第二行包含N个整数A1, A2, … AN。 第三行包含N个整数B1, B2, … BN。
第四行包含N个整数C1, C2, … CN。
对于30%的数据,1 <= N <= 100
对于60%的数据,1 <= N <= 1000
对于100%的数据,1 <= N <=100000 0 <= Ai, Bi, Ci <= 100000
【输出格式】 一个整数表示答案
【输入样例】
3
1 1 1
2 2 2
3 3 3
【输出样例】
27
分析:
首先,我会想到将abc序列分别排序(可以直接使用algorithm库里的sort函数),排序后,遍历a序列,依次取出a序列中的各个元素,每每取出一个元素,分别遍历b序列和c序列,找到符合条件:
i1 <= i, j, k <= N
Ai < Bj < Ck
的序列。然而这种方法耗费的时间太长,如果序列元素个数很大,会超时。所以,采用另一种更省时的方法:各个序列先排序,排序后,以b序列为参照,挨个遍历b序列,记b[i],此时,分别找a序列中小于b[i]有多少个数,记p,c序列中有多少大于b[i]序列的数,记q。故,遍历到b[i]时,有p*q种符合条件的序列。
代码如下:
#include<iostream>
#include<algorithm>
using namespace std;
int main(){
int N;
cin>>N;
int a[N],b[N],c[N];
for(int i = 0;i < N;++i){
cin>>a[i];
}
for(int i = 0;i < N;++i){
cin>>b[i];
}
for(int i = 0;i < N;++i){
cin>>c[i];
}
sort(a,a+N);
sort(b,b+N);
sort(c,c+N);
int p=0,q=0,sum=0;
for(int j=0;j<N;++j){
while(a[p]<b[j]&&p<N) ++p;
while(c[q]>=b[j]&&q<N) ++q;
sum+=p*q;
}
cout<<sum<<endl;
return 0;
}