给定三个整数数组
A=[A1,A2,…AN]
B=[B1,B2,…BN]
C=[C1,C2,…CN]请你统计有多少个三元组 (i,j,k) 满足:
- 1≤i,j,k≤N
- Ai<Bj<Ck
输入格式
第一行包含一个整数 N。
第二行包含 N 个整数 A1,A2,…AN
第三行包含 N个整数 B1,B2,…BN
第四行包含 N 个整数 C1,C2,…CN
输出格式
一个整数表示答案。
数据范围
1≤N≤105
0≤Ai,Bi,Ci≤105输入样例:
3 1 1 1 2 2 2 3 3 3
输出样例:
27
暴力n^3超时,以中间b为标准二分,分别求出某个b在c里有多少个大于这个b的,在a里有多少个小于b的这样就是nlogn
#include <iostream>
#include <algorithm>
using namespace std;
constexpr int N=1e5+7;
int n,a[N],b[N],c[N];
long long ans[N];
long long res;
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
}
for(int i=1;i<=n;i++){
scanf("%d",&b[i]);
}
for(int i=1;i<=n;i++){
scanf("%d",&c[i]);
}
sort(a+1,a+n+1);
sort(b+1,b+n+1);
sort(c+1,c+n+1);
for(int i=1;i<=n;i++){
int l=0,r=n;
while(l<r){
int mid=(l+r+1)>>1;
if(a[mid]<b[i]) l=mid;
else r=mid-1;
}
if(a[1]>=b[i]) ans[i]=0;
else ans[i]=l;
}
for(int i=1;i<=n;i++){
int l=0,r=n;
while(l<r){
int mid=(l+r)>>1;
if(c[mid]>b[i]) r=mid;
else l=mid+1;
}
if(c[n]<=b[i]) ans[i]=0;
else ans[i]=ans[i]*(long long)(n-l+1);
}
for(int i=1;i<=n;i++){
res+=ans[i];
}
printf("%lld",res);
}