优化小技巧 acwing.1236递增三元组
平时我们用前缀和都是快速统计某个连续范围的和
所以我们可以利用这一点 对如下类型题目进行优化
给定三个整数数组
A=[A1,A2,…AN],
B=[B1,B2,…BN],
C=[C1,C2,…CN],
请你统计有多少个三元组 (i,j,k)
满足:
1≤i,j,k≤N
Ai<Bj<Ck
输入格式
第一行包含一个整数 N。
第二行包含 N个整数 A1,A2,…AN。
第三行包含 N个整数 B1,B2,…BN。
第四行包含 N个整数 C1,C2,…CN。
输出格式
一个整数表示答案。
数据范围
1≤N≤10^5,
0≤Ai,Bi,Ci≤10^5
输入样例:
3
1 1 1
2 2 2
3 3 3
输出样例:
27
首先我们理解题目的意思 是要找有多少满足 i<j<k且a[i]<b[j]<c[k]形式的abc三元组
我们用最简单的思路就会想到
for(枚举a数组)
for(枚举b数组)
for(枚举c数组)
判断满足三元组条件
res++
这样的话时间复杂度就来到了O(n^3)次方
题目的数据范围是1e5
这就提示我们 算法的复杂度要控制在O(n) O(nlogn) 左右
也就是说我们只能枚举一个数组
a b c中a c没有直接联系 所以我们可以枚举b (如果枚举a的话 b c 都大于a 后面性质不好推导)
我们枚举b 那么我们只需要找比b小的有多少个 比b大的有多少个 然后乘法原理 相乘就能得到某个b能构成多少个三元组
我们要能快速得出比b小有多少个 比b大有多少个
这时候就有了不同的优化思路
①用前缀和(因为我们看到数据范围是1~1e5 我们能开这样一个数组s[N] 表示1到n的数的个数一共有多少个 维护cnt[N]桶计数某个数有多少个 s[i]=s[i-1]+cnt[i])
那么比b[i]小的数共有多少个就是 s[b[i]-1]
比b[i]大的数个数用对立的思想计算 (所有数的个数和-1到b[i]的数个数和)
即 s[N-1]-s[b[i]]
下面请看代码
#include <iostream>
#include <cstring>
using namespace std;
const int N=100010;
typedef long long LL;
int b[N],as[N],cs[N],cnt[N],s[N];//b数组 as记录比b[i]小的总数,cs记录比b[i]大的总数
int main(void)
{
int n;
cin>>n;
for(int i=0;i<n;i++)
{
int x;
cin>>x;
cnt[++x]++;//把数据范围0~1e5平移成1~1e5+1 方便前缀和
}
for(int i=1;i<N;i++)s[i]=s[i-1]+cnt[i];
for(int i=1;i<=n;i++)cin>>b[i],b[i]++;//b数组跟着平移
for(int i=1;i<=n;i++)as[i]=s[b[i]-1];
memset(cnt,0,sizeof cnt);//统计c中比b大的 所以要把a的记录清空
for(int i=1;i<=n;i++)
{
int x;
cin>>x;
cnt[++x]++;
}
for(int i=1;i<N;i++)s[i]=s[i-1]+cnt[i];
for(int i=1;i<=n;i++)cs[i]=s[N-1]-s[b[i]];
LL res=0;
for(int i=1;i<=n;i++)
{
res+=1LL*as[i]*cs[i];//防止int*int溢出 先转型LL
}
cout<<res<<endl;
return 0;
}
②用二分,将a,c数组排序后,枚举b数组 在a数组找到最后一个小于b的位置,在c数组找到大于b的第一个位置
然后根据index关系算出个数
请看代码:
#include <iostream>
#include <algorithm>
using namespace std;
const int N=100010;
typedef long long LL;
int a[N],b[N],c[N];
/*
排序 二分出边界下标统计比b[i]小和大的数的个数
*/
int main(void)
{
int n;
scanf("%d",&n);
for(int i=0;i<n;i++)cin>>a[i];
for(int i=0;i<n;i++)cin>>b[i];
for(int i=0;i<n;i++)cin>>c[i];
sort(a,a+n),sort(c,c+n);
LL res=0;
for(int i=0;i<n;i++)
{
int x=b[i];
int l=0,r=n-1;
while(l<r)//找到<x最大的数
{
int mid=l+r+1>>1;
if(a[mid]<x)l=mid;
else r=mid-1;
}
if(a[l]>=x)continue;//找不到
int sl=l+1;//下标从0开始 所以加1即长度
l=0,r=n-1;
while(l<r)//找到大于x的第一个数
{
int mid=l+r>>1;
if(c[mid]>x)r=mid;
else l=mid+1;
}
if(c[l]<=x)continue;
int sr=n-l;//(下标从0开始 所以n-l即长度)
res+=1LL*sl*sr;
}
cout<<res<<endl;
return 0;
}
③双指针思想:
找位置是单调的 即先把a b c 排序 O(nlogn)
再枚举b O(n) b增加 a中小于b的数也会增加 (单调不减) c中大于b的数也会单调减少(单调不增)
避免了重复寻找
看代码:
#include <iostream>
#include <algorithm>
using namespace std;
const int N=100010;
typedef long long LL;
int a[N],b[N],c[N];
int main(void)
{
int n;
scanf("%d",&n);
for(int i=0;i<n;i++)cin>>a[i];
for(int i=0;i<n;i++)cin>>b[i];
for(int i=0;i<n;i++)cin>>c[i];
sort(a,a+n),sort(b,b+n),sort(c,c+n);
LL res=0;
//b↑a↑c↑
for(int i=0,j=0,k=0;i<n;i++)
{
int x=b[i];
while(a[j]<x&&j<n)j++;//找到第一个大于等于x的下标(下标从0开始 所以idx即为长度)
while(c[k]<=x&&k<n)k++;//找到第一个大于x的的下标(从0开始 n-idx为长度)
res+=1LL*j*(n-k);
}
cout<<res<<endl;
return 0;
}