给定三个整数数组
A = [A1, A2, ... AN],
B = [B1, B2, ... BN],
C = [C1, C2, ... CN],
请你统计有多少个三元组(i, j, k) 满足:
1. 1 <= i, j, k <= N
2. Ai < Bj < Ck
输入格式
第一行包含一个整数N。
第二行包含N个整数A1, A2, ... AN。
第三行包含N个整数B1, B2, ... BN。
第四行包含N个整数C1, C2, ... CN。
对于30%的数据,1 <= N <= 100
对于60%的数据,1 <= N <= 1000
对于100%的数据,1 <= N <= 100000 0 <= Ai, Bi, Ci <= 100000
输出格式
一个整数表示答案
样例输入
3
1 1 1
2 2 2
3 3 3
样例输出
27
思路:对于数组B从1到n,只需用数组A中小于B[j]的个数乘以数组C中大于B[i]的个数即可,累加即可
#include<iostream>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long ll;
ll a[1000005];
ll b[1000005];
ll c[1000005];
int main() {
ll n, s1 = 0, s2 = 0, i, k;
ll ans = 0;
cin >> n;
i = 1;
k = 1;
for (ll qq = 1; qq <= n; qq++)
{
cin >> a[qq];
}
for (ll qq = 1; qq <= n; qq++)
{
cin >> b[qq];
} for (ll qq = 1; qq <= n; qq++)
{
cin >> c[qq];
}
sort(a + 1, a + n + 1);
sort(b + 1, b + n + 1);
sort(c + 1, c + n + 1);
for (ll j = 1; j <= n; j++)
{ //减少了不必要的循环(因为数组都是递增的)
for (; i <= n; i++)
{
if (a[i] < b[j])
{
s1++;
}
else
{
break;
}
}
for (;; k++)
{
if (b[j] < c[k])
{
s2 = n - k + 1;
break;
}
if (k == n)//此时已经遍历完整个c[]了,依然无法找到比b[j]大的数
{
s2 = 0;
break;
}
}
if (s2 == 0 || s1 == 0)continue;//只要有为0的情况必定无法组成三元组
ans += (s1 * s2);
}
cout << ans;
return 0;
system("pause");
}