问题描述
给定三个整数数组
A=[A1,A2,…AN]
B=[B1,B2,…BN]
C=[C1,C2,…CN]请你统计有多少个三元组 (i,j,k) 满足:
- 1≤i,j,k≤N
- Ai<Bj<Ck
输入格式
第一行包含一个整数 N。
第二行包含 N 个整数 A1,A2,…AN。
第三行包含 N 个整数 B1,B2,…BN。
第四行包含 N 个整数 C1,C2,…CN。
输出格式
一个整数表示答案
数据范围
1≤N≤10^5
0≤Ai,Bi,Ci≤10^5
输入样例
3
1 1 1
2 2 2
3 3 3
输出样例
27
题目分析:
我们容易想到,纯暴力的做法为,直接三重循环遍历三个数组,但是时间复杂度高达O(n^ 3),由题目中数据范围为10^5可知,纯暴力的方法不可取,会超时。
由数据范围10^5,我们可以推测此题的正解复杂度为O(n) 或 O(nlogn),且大概率为后者。
在上面的基础上,我们容易发现在优化后的解法中,可以遍历的层数最多为1层,也就是说,我们只能遍历三个数组中的一个。题目要找的解为满足Ai<Bj<Ck的解,所以A和C的性质类似,两者相互独立。如果我们遍历,那么计算出A和C 后就可以通过乘法原理计算answer。而遍历A的话,由于B和C不是相互独立的,所以计算更为复杂。于是,我们选择遍历B,去计算A和C。
计算A和C的方法由多种,例如我们可以通过前缀和的方法,求得在A中小于B[ i ]的值的个数,同理,也能求在C中大于B[ i ]的值的个数, 时间复杂度为O(n)。另外,我们还可以通过将A与C进行sort + 二分的过程求解,时间复杂度为O(nlogn)
解法一:
//前缀和
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N = 100010;
int a[N], b[N], c[N];
int as[N], cs[N];
int s[N], cnt[N];
int n;
int main(){
//读入
cin >> n;
for(int i = 0; i < n; i ++ ) scanf("%d", &a[i]), a[i] ++ ;
for(int i = 0; i < n; i ++ ) scanf("%d", &b[i]), b[i] ++ ;
for(int i = 0; i < n; i ++ ) scanf("%d", &c[i]), c[i] ++ ;
//构建前缀和as[]与cs[]
for(int i = 0; i < n; i ++ ) cnt[a[i]] ++ ;
for(int i = 1; i < N; i ++ ) s[i] = s[i - 1] + cnt[i];
for(int i = 0; i < n; i ++ ) as[i] = s[b[i] - 1];
memset(cnt, 0, sizeof(cnt));
memset(s, 0, sizeof(s));
for(int i = 0; i < n; i ++ ) cnt[c[i]] ++ ;
for(int i = 1; i < N; i ++ ) s[i] = s[i - 1] + cnt[i];
for(int i = 0; i < n; i ++ ) cs[i] = s[N - 1] - s[b[i]];
//answer的计算
long long ans = 0;
for(int i = 0; i < n; i ++ ){
long long tem = (long long) as[i] * cs[i];
ans += tem;
}
cout << ans << endl;
return 0;
}
解法二:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long lld;
const int N = 100005;
int a[N], b[N], c[N];
int n;
lld sum;
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &a[i]);
for (int i = 1; i <= n; i++)
scanf("%d", &b[i]);
for (int i = 1; i <= n; i++)
scanf("%d", &c[i]);
//由于二分的前提是单调序列 所以预先对a b c排序 直接sort
sort(a + 1, a + 1 + n);
sort(b + 1, b + 1 + n);
sort(c + 1, c + 1 + n);
for (int i = 1; i <= n; i++)
{
//直接用STL中的两个二分函数解决
lld x = (lower_bound(a + 1, a + 1 + n, b[i]) - a) - 1; //在数组a中找比b[i]小的数
lld y = n - (upper_bound(c + 1, c + 1 + n, b[i]) - c) + 1; //在数组c中找比b[i]大的数
sum += x * y;
}
printf("%lld", sum);
return 0;
}