题目描述
给定三个整数数组:A = [A1, A2, … AN],B = [B1, B2, … BN],C = [C1, C2, … CN],
请你统计有多少个三元组 (i, j, k) 满足:
- 1 ≤ i, j, k ≤ N
- Ai < Bj < Ck
输入格式
第一行包含一个整数 N
第二行包含 N 个整数 A1, A2, … AN
第三行包含 N 个整数 B1, B2, … BN
第四行包含 N 个整数 C1, C2, … CN
输出格式
一个整数表示答案
样例输入
3
1 1 1
2 2 2
3 3 3
样例输出
27
数据范围
对于 30% 的数据,1 ≤ N ≤ 100
对于 60% 的数据,1 ≤ N ≤ 1000
对于 100% 的数据,1 ≤ N ≤ 105,0 ≤ Ai, Bi, Ci ≤ 105
题解一(超时)
暴力枚举:
O
(
N
3
)
O(N^3)
O(N3)
#include <iostream>
using namespace std;
const int N = 100010;
int n;
int a[N], b[N], c[N];
int main()
{
cin >> n;
for (int i = 0; i < n; i ++) scanf("%d", &a[i]);
for (int i = 0; i < n; i ++) scanf("%d", &b[i]);
for (int i = 0; i < n; i ++) scanf("%d", &c[i]);
int ans = 0;
for (int i = 0; i < n; i ++)
for (int j = 0; j < n; j ++)
for (int k = 0; k < n; k ++)
if(b[j] > a[i] && b[j] < c[k])
ans ++;
cout << ans << endl;
return 0;
}
题解二
双指针:
O
(
N
l
o
g
N
)
O(NlogN)
O(NlogN)
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 100010;
int n;
int a[N], b[N], c[N];
int main()
{
cin >> n;
for (int i = 1; i <= n; i ++) scanf("%d", &a[i]);
for (int i = 1; i <= n; i ++) scanf("%d", &b[i]);
for (int i = 1; i <= n; i ++) scanf("%d", &c[i]);
sort(a + 1, a + 1 + n);
sort(b + 1, b + 1 + n);
sort(c + 1, c + 1 + n);
LL ans = 0;
int L = 1, R = 1;
for (int i = 1; i <= n; i ++)
{
while(L <= n && a[L] < b[i]) L ++;
while(R <= n && c[R] <= b[i]) R ++;
ans += (LL) (L - 1) * (n - R + 1);
}
cout << ans << endl;
return 0;
}
题解三
二分:
O
(
N
l
o
g
N
)
O(NlogN)
O(NlogN)
#include <iostream>
#include <algorithm>
#include <cstdio>
using namespace std;
typedef long long LL;
const int N = 100010;
int n;
int a[N], b[N], c[N];
int find1(int x) // 在 A 数组中找出小于 b[i] 的第一个数
{
int l = 1, r = n;
while(l < r)
{
int mid = l + r + 1 >> 1;
if(a[mid] < x) l = mid;
else r = mid - 1;
}
return l;
}
int find2(int x) // 在 C 数组中找出大于 b[i] 的第一个数
{
int l = 1, r = n;
while(l < r)
{
int mid = l + r >> 1;
if(c[mid] > x) r = mid;
else l = mid + 1;
}
return l;
}
int main()
{
cin >> n;
for (int i = 1; i <= n; i ++) scanf("%d", &a[i]);
for (int i = 1; i <= n; i ++) scanf("%d", &b[i]);
for (int i = 1; i <= n; i ++) scanf("%d", &c[i]);
sort(a + 1, a + 1 + n);
sort(c + 1, c + 1 + n);
LL ans = 0;
for (int i = 1; i <= n; i ++)
{
int L = find1(b[i]), R = find2(b[i]);
if(a[L] < b[i] && c[R] > b[i]) ans += (LL) L * (n - R + 1);
}
cout << ans << endl;
return 0;
}
题解四
前缀和:
O
(
N
)
O(N)
O(N)
#include <iostream>
#include <cstring>
#include <cstdio>
using namespace std;
typedef long long LL;
const int N = 100010;
int n;
int s[N]; // 前缀和数组,s[i]:数字 [1 ~ i] 出现的总次数
int cnt[N]; // cnt[i]:统计数字 i 出现的次数
int L[N], R[N]; // L[i]:A数组中小于 b[i] 的元素个数; R[i]:C数组中大于 b[i] 的元素个数;
int a[N], b[N], c[N];
int main()
{
cin >> n;
for (int i = 0; i < n; i ++) scanf("%d", &a[i]), a[i] ++; // +1 是为了便于操作前缀和数组
for (int i = 0; i < n; i ++) scanf("%d", &b[i]), b[i] ++;
for (int i = 0; i < n; i ++) scanf("%d", &c[i]), c[i] ++;
for (int i = 0; i < n; i ++) cnt[a[i]] ++;
for (int i = 1; i < N; i ++) s[i] = s[i - 1] + cnt[i];
for (int i = 0; i < n; i ++) L[i] = s[b[i] - 1];
memset(s, 0, sizeof s);
memset(cnt, 0, sizeof cnt);
for (int i = 0; i < n; i ++) cnt[c[i]] ++;
for (int i = 1; i < N; i ++) s[i] = s[i - 1] + cnt[i];
for (int i = 0; i < n; i ++) R[i] = s[N - 1] - s[b[i]];
LL ans = 0;
for (int i = 0; i < n; i ++)
ans += (LL) L[i] * R[i];
cout << ans << endl;
return 0;
}