题目描述
小明是今年超级跳棋比赛的裁判,每轮有三名选手参加,结束时统计的分数一定是正整数,形如 a:b:c。小明的任务是在一块特殊的计分板上展示分数,他一共准备了 n 块写有正整数 x1,x2,……,xn的卡片,可供填写在 a、b、c 的位置上。此外,小明了解到超级跳棋的规则,他发现 a、b、c 之间最多相差 k 倍,例如 c/a>k就是不合法的分数。为了检验他准备得是否充分,你需要计算小明可以在计分板上摆放出多少种不同的分数,即 (a,b,c) 这样的三元组有多少个。
限制
1s 256M
对于 20% 的数据,3≤n≤100,000,k=1,1≤xi≤100,0003≤n≤100,000,k=1,1≤xi≤100,000
对于另外 20% 的数据,3≤n≤100,1≤k≤100,1≤xi≤1003≤n≤100,1≤k≤100,1≤xi≤100
对于另外 30% 的数据,3≤n≤100,000,1≤k,xi≤10^9且所有xi互不相同3≤n≤100,000,1≤k,xi≤10^9且所有xi互不相同
对于另外 30% 的数据,3≤n≤100,000,1≤k,xi≤10^9,3≤n≤100,000,1≤k,xi≤10^9
输入格式
第一行,两个整数 n 和 k
第二行,n 个整数 x1,x2,……,xn
输出格式
一个整数,表示 (a,b,c) 三元组的个数
输入样例
5 2
1 1 2 2 3
输出样例
9
样例解释
小明可以摆出的 a:b:c 有以下这些:1:1:2、1:2:1、2:1:1、1:2:2、2:1:2、2:2:1、2:2:3、2:3:2、3:2:2。由于 k=2,k=2,1 和 3 不能同时出现。
由于三元组中任意两个数最多相差k倍,所以只需比较一个三元组中最大数和最小数相差多少倍,就可以判断它是否符合条件。因此,我们将x数组排序,枚举三元组的第一个数a,确定一个a后,就可以确定剩下两个数b和c的在数组x中的下标范围。如样例中,当a=x[1]时,b,c只能从x[2~4]中选。
本来由一个三元组合可以衍生出六(3*2*1)种不同的排列,但卡片上的数可能重复,便产生了分类讨论的必要。如下:
- 当a,b,c各不相同时,由于我们枚举a,a已确定,因此只要在可选数(与a相差最多k倍的数)中选出两个不同的即可。
- 当b=c时,要在可选数中选择两个相同的数,我们得提前预处理好每个x[i]的相同数的数量,才能判断一个可选数相同数的数量是否大于等于2。
- 当a=b时,前提是当前枚举的a的相同数的数量大于等于2,剩下一个数从可选数中随便选。
- 当a=b=c时,如果当前枚举的a的相同数的数量大于等于3即可。
如果你理解了上述的话,那么最后对答案的统计也就很简单了。比如第一种情况,那么对于一个a而言,假设可选数中不同数的数量是s,答案要加上s*(s-1)/2*6。剩下的情况留给读者自己思考。
附上Yuns的代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int read(){
int x=0,f=1; char ch=getchar();
while (ch<'0'||ch>'9'){if (ch=='-') f=-1; ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0'; ch=getchar();}
return x*f;
}
int a[100005],b[100005],d[100005],p,k,n;
int s[100005];
ll ans=0;
int find(int x){
int l=1,r=p,ans=-1;
ll q=(ll)x*k;
while (l<=r){
int mid=(l+r)>>1;
if (d[mid]<=q){ l=mid+1; ans=mid;} else r=mid-1;
}
return ans;
}
int main(){
n=read(),k=read();
for (int i=1;i<=n;i++)
a[i]=read();
sort(a+1,a+1+n);
for (int i=1;i<=n;i++){
if (a[i]!=a[i-1]) p++,d[p]=a[i];
b[p]++;
}
for (int i=1;i<=p;i++){
if (b[i]>=2) s[i]=s[i-1]+1;
else s[i]=s[i-1];
}
for (int i=1;i<=p;i++){
int q1=find(d[i]);
if (q1==-1||q1<i) continue;
int kd=q1-i;
ans+=1ll*kd*(kd-1)*3;
ans+=1ll*(s[q1]-s[i])*3;
if (b[i]>=2) ans+=1ll*kd*3;
if (b[i]>=3) ans++;
}
cout<<ans<<endl;
return 0;
}