题意
给一可重集 S S S 及质数 p p p,问能从集合中选出多少种本质不同的三元组 ( x , y , z ) ( x ≤ y ≤ z ) (x,y,z)\ (x\leq y\leq z) (x,y,z) (x≤y≤z) 使得 x y z ≡ 1 ( m o d p ) xyz\equiv 1\pmod p xyz≡1(modp)? ∣ S ∣ ≤ 2333 , p ≤ 2 30 |S|\leq 2333,p\leq 2^{30} ∣S∣≤2333,p≤230
70% 解
保证 x < p ∣ x ∈ S x<p|x\in S x<p∣x∈S。首先排序所有数并预处理出其逆元。接着枚举 x , y ( x ≤ y ) x,y(x\leq y) x,y(x≤y),算出 z = x − 1 y − 1 z=x^{-1}y^{-1} z=x−1y−1,检查是否 y ≤ z y\leq z y≤z 以及是否 z ∈ S z\in S z∈S(二分)。注意处理有数字相等的情况。
100% 解
排序时优先根据 m o d p \mod p modp 结果排序,并将所有数字去重后的结果另外存储一份。仍然枚举 x , y ( x ≤ y ) x,y(x\leq y) x,y(x≤y),算出 z z z,统计与 z z z 在同一剩余系下的数字个数。同样注意处理有数字相等的情况。
代码:
#include<bits/stdc++.h>
using namespace std;
int getint(){
int ans=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){
if(c=='-')f=-1;
c=getchar();
}
while(c>='0'&&c<='9'){
ans=ans*10+c-'0';
c=getchar();
}
return ans;
}
int p;
bool cmp(int a,int b){
if(a%p!=b%p)return a%p<b%p;
return a<b;
}
bool cmp2(int a,int b){
return a%p<b%p;
}
inline int qpow(int x,int y,int z){
int ans=1;
while(y){
if(y&1)ans=ans*1ll*x%z;
x=x*1ll*x%z;
y>>=1;
}
return ans;
}
int a[3000];
int inv[3000];
int uniq[3000];
int main(){
int n=getint();
p=getint();
int sqrtn=sqrt(n);
for(int i=0;i<n;i++){
a[i]=getint();
}
sort(a,a+n,cmp);
memcpy(uniq,a,sizeof(uniq));
int r_=unique(uniq,uniq+n)-uniq;
for(int i=0;i<n;i++){
inv[i]=qpow(a[i],p-2,p);
}
int ans=0;
//a[i]<a[j]<=a[k]
for(int i=0;i<n;i++){
if(a[i]==0)continue;
if(i&&a[i-1]==a[i])continue;
for(int j=i+1;j<n;j++){
if(j&&a[j-1]==a[j])continue;
if(a[i]==a[j])continue;
int ak=inv[i]*1ll*inv[j]%p;
//cout<<">> ? "<<a[i]<<" "<<a[j]<<" "<<ak<<endl;
if(ak%p<a[j]%p)continue;
if(ak%p==a[j]%p){
ans+=(a[j+1]==a[j]);
//cout<<"> "<<ans<<endl;
ans+=upper_bound(uniq,uniq+r_,ak,cmp2)
-upper_bound(uniq,uniq+r_,a[j],cmp);
//cout<<"> "<<ans<<endl;
}else{
ans+=upper_bound(uniq,uniq+r_,ak,cmp2)
-lower_bound(uniq,uniq+r_,ak,cmp2);
}
}
}
//a[i]=a[j]<=a[k]
for(int i=0;i<n-1;i++){
if(a[i]!=a[i+1])continue;
if(i&&a[i-1]==a[i])continue;
if(a[i]==0)continue;
int ak=inv[i]*1ll*inv[i]%p;
if(ak%p<a[i]%p)continue;
if(ak%p==a[i]%p){
ans+=(a[i+2]==a[i]);
//cout<<"> "<<ans<<endl;
ans+=upper_bound(uniq,uniq+r_,ak,cmp2)
-upper_bound(uniq,uniq+r_,a[i],cmp);
//cout<<"> "<<ans<<endl;
}else{
ans+=upper_bound(uniq,uniq+r_,ak,cmp2)
-lower_bound(uniq,uniq+r_,ak,cmp2);
}
}
printf("%d\n",ans);
return 0;
}