原题链接:递增三元组
题目描述
给定三个整数数组
A
=
[
A
1
,
A
2
,
…
A
N
]
,
A=[A_1,A_2,…A_N],
A=[A1,A2,…AN],
B
=
[
B
1
,
B
2
,
…
B
N
]
,
B=[B_1,B_2,…B_N],
B=[B1,B2,…BN],
C
=
[
C
1
,
C
2
,
…
C
N
]
,
C=[C_1,C_2,…C_N],
C=[C1,C2,…CN],
请你统计有多少个三元组 ( i , j , k ) (i,j,k) (i,j,k) 满足:
- 1 ≤ i , j , k ≤ N 1≤i,j,k≤N 1≤i,j,k≤N
- A i < B j < C k Ai<Bj<Ck Ai<Bj<Ck
输入格式
第一行包含一个整数 N N N。
第二行包含 N N N 个整数 A 1 , A 2 , … A N 。 A_1,A_2,…A_N。 A1,A2,…AN。
第三行包含 N N N 个整数 B 1 , B 2 , … B N 。 B_1,B_2,…B_N。 B1,B2,…BN。
第四行包含 N N N 个整数 C 1 , C 2 , … C N 。 C_1,C_2,…C_N。 C1,C2,…CN。
输出格式
一个整数表示答案。
数据范围
1
≤
N
≤
105
,
1≤N≤105,
1≤N≤105,
0
≤
A
i
,
B
i
,
C
i
≤
105
0≤Ai,Bi,Ci≤105
0≤Ai,Bi,Ci≤105
输入样例1:
3
1 1 1
2 2 2
3 3 3
输出样例1:
27
思路
我自己想的时候是枚举 a a a 组中的每一个元素,然后二分出 b b b 和 c c c 中的边界点,然后相乘。但是自己想错了, 因为 b b b 和 c c c 有关系,所以不能直接相乘。
那么其实可以枚举 b b b 中的每一个点,再根据 b i b_i bi 找到 a , c a,c a,c 中小于 b i bi bi 的数 和 大于 b i bi bi 的数 ,因为如此分, a a a 和 c c c 就是独立存在,可以用乘法原理
如果暴力解决,那么必然是需要 O ( n 3 ) O(n^3) O(n3)的复杂度枚举三个点再判断其大小,必定超时
所以可以想办法优化,可以看出我们找小于和大于 b i b_i bi 的数时,可以
- 排序 + 二分找到边界点
- 前缀和求小于 b i b_i bi 的个数和 大于 b i b_i bi 的个数
- 双指针
排序 + 二分
对三个数组排序,枚举 b b b 的每一个点,然后二分出 a a a 中小于 b i b_i bi 的第一个数和 b b b 中大于 b i b_i bi 的第一个数,算出个数相乘即可
时间复杂度
O ( n l o g n ) O(nlogn) O(nlogn)
C++ 代码
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 100010;
int q[3][N];
int n;
int lowbound(int i,int x) {
int l = 0, r = n - 1;
while(l < r) {
int mid = l + r + 1>> 1;
if(q[i][mid] < x) l = mid;
else r = mid - 1;
}
if(q[i][l] >= x) l = -1;
return l;
}
int upbound(int i,int x) {
int l = 0, r = n - 1;
while(l < r) {
int mid = l + r >> 1;
if(q[i][mid] > x) r = mid;
else l = mid + 1;
}
if(q[i][r] <= x) r = n;
return r;
}
int main()
{
cin >> n;
for(int i = 0;i < 3;i ++) {
for(int j = 0;j < n;j ++) scanf("%d",&q[i][j]);
sort(q[i],q[i] + n);
}
LL res = 0;
for(int i = 0;i < n;i ++) {
int x = q[1][i];
int l1 = lowbound(0,x);
int l2 = upbound(2,x);
res += (LL)(l1 + 1) * (n - l2);
}
cout << res << endl;
return 0;
}
前缀和
时间复杂度
O ( N ) O(N) O(N)
C++ 代码
#include <cstring>
#include <iostream>
using namespace std;
typedef long long LL;
const int N = 100010;
int a[N], b[N], c[N];
int as[N]; // as[i]表示小于b[i]的数的个数
int cs[N]; // cs[i]表示大于b[i]的数的个数
int cnt[N], s[N];
int n;
int main()
{
scanf("%d",&n);
for(int i = 0;i < n;i ++) scanf("%d",&a[i]), a[i] ++;
for(int i = 0;i < n;i ++) scanf("%d",&b[i]), b[i] ++;
for(int i = 0;i < n;i ++) scanf("%d",&c[i]), c[i] ++;
for(int i = 0;i < n;i ++) cnt[a[i]] ++;
for(int i = 1;i < N;i ++) s[i] = s[i - 1] + cnt[i];
// 求as
for(int i = 0;i < n;i ++) as[i] = s[b[i] - 1];
memset(cnt, 0, sizeof cnt);
memset(s, 0, sizeof s);
// 求cs
for(int i = 0;i < n;i ++) cnt[c[i]] ++;
for(int i = 1;i < N;i ++) s[i] = s[i - 1] + cnt[i];
for(int i = 0;i < n;i ++) cs[i] = s[N - 1] - s[b[i]];
LL res = 0;
for(int i = 0;i < n;i ++) res += (LL) as[i] * cs[i];
cout << res << endl;
return 0;
}