Problem
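There are n nonnegative integers $a_{1 \dots n}$ which are less than p, where p is a prime. HazelFan wants to know how many pairs $i, j$ ($1 \le i < j \le n$) satisfy $\frac{1}{a_i + a_j} \equiv \frac{1}{a_i} + \frac{1}{a_j} \pmod p$, i.e. the modular inverse of the sum equals the sum of the modular inverses. (Zero has no modular inverse.)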
Idea
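在 $a_i, a_j, a_i + a_j \not\equiv 0 \pmod p$(逆元存在)的前提下,以下公式 (1), (2), (3), (4) 均等价:
(1) $\frac{1}{a_i + a_j} \equiv \frac{1}{a_i} + \frac{1}{a_j} \pmod p$
(2) $\frac{1}{a_i + a_j} \equiv \frac{a_i + a_j}{a_i a_j} \pmod p$
(3) $a_i a_j \equiv (a_i + a_j)^2 \pmod p$
(4) $a_i^2 + a_i a_j + a_j^2 \equiv 0 \pmod p$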
当 $a_i = a_j$ 时,公式 (4) 等价于 $3a_i^2 \equiv 0 \pmod p$。由于 p 为质数且 $a_i \not\equiv 0$,当且仅当 p=3 时该式可能成立;对其它任何质数 p,$3 \times a_i^2 \bmod p = 0$ 都不可能成立。p=3 时,因为 $a_i$ 的取值一定小于 p 且不为 0,故仅 $a_i = 1, 2$ 可取,暴力统计即可。
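当 $a_i \ne a_j$ 时,$a_i - a_j$ 在模 p 下可逆,将公式 (4) 两边同乘 $(a_i - a_j)$ 得 $a_i^3 - a_j^3 \equiv 0 \pmod p$。因此:
若 $a_i^3 \equiv a_j^3 \pmod p$ 成立,则 (i, j) 是一组合法的解。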
因此,当 $p \ne 3$ 时,用快速乘法(避免两个 long long 直接相乘溢出)计算每个 $a_i^3 \bmod p$,统计每种相同余数出现的次数 number,由组合原理得该余数贡献的方案数为 $\frac{number \times (number - 1)}{2}$。
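当然,上述推导的前提是 $a_i \ne a_j$;而当 $a_i = a_j$ 时 $a_i^3 \equiv a_j^3 \pmod p$ 必定成立,却并非合法解($p \ne 3$),故应先减去所有 $a_i = a_j$ 的配对对答案的贡献。$a_i = 0$ 的配对也由此被抵消:0 的立方只与 0 的立方同余,而这些配对都属于 $a_i = a_j$ 的情形。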
……此题貌似我 C++ 敲得太渣,竟然没有 Java 代码跑得快。
Code(C++ Edition)
#include <bits/stdc++.h>
using namespace std;

// Size of the 1-based value array: room for 100000 values plus a small margin.
const int N = 100010;

int T;          // number of test cases still to process
int n;          // number of values in the current test case
long long p;    // modulus of the current test case
long long a[N]; // values of the current test case, stored at indices 1..n
// Returns (a * b) mod p (p is the global modulus) without ever forming the full
// product: "double-and-add" multiplication that only adds values already reduced
// modulo p.
//
// Each modular addition uses the form  x + y mod p == (x >= p - y) ? x - (p - y) : x + y
// for x, y in [0, p). That form never exceeds p, so the routine is correct for any
// modulus up to LLONG_MAX. The naive `x + y` / `a *= 2` overflows signed long long
// once p > LLONG_MAX / 2.
//
// Assumes a >= 0, b >= 0 and p > 0 (a negative b would never terminate).
long long quick_mul(long long a, long long b) {
    long long res = 0;
    a %= p;  // keep the running addend in [0, p) so the safe-add invariant holds
    while (b) {
        if (b & 1)
            res = (res >= p - a) ? res - (p - a) : res + a;  // res = (res + a) mod p
        a = (a >= p - a) ? a - (p - a) : a + a;              // a = (2 * a) mod p
        b >>= 1;
    }
    return res;
}
// Counts index pairs (i, j), 1 <= i < j <= n, that lie in the same contiguous run of
// equal values in a[1..n]. A run of length k contributes k*(k-1)/2 pairs.
// main sorts a[1..n] before calling, so the result is the number of pairs with
// a[i] == a[j].
long long calc() {
    long long total = 0;
    int runStart = 1;
    while (runStart <= n) {
        int runEnd = runStart;
        while (runEnd + 1 <= n && a[runEnd + 1] == a[runStart])
            ++runEnd;
        long long len = runEnd - runStart + 1;
        total += len * (len - 1) / 2;
        runStart = runEnd + 1;
    }
    return total;
}
// Reads T test cases (each: n, p, then n values) and prints one answer per case.
// Stops early if reading a test-case header hits EOF.
int main()
{
    scanf("%d", &T);
    while (T-- && scanf("%d %lld", &n, &p) != EOF) {
        for (int i = 1; i <= n; i++)
            scanf("%lld", &a[i]);

        long long ans;
        if (p == 3) {
            // For p == 3 only pairs of equal values 1 or 2 satisfy the condition.
            long long ones = 0, twos = 0;
            for (int i = 1; i <= n; i++) {
                if (a[i] == 1)
                    ones++;
                else if (a[i] == 2)
                    twos++;
            }
            ans = ones * (ones - 1) / 2 + twos * (twos - 1) / 2;
        } else {
            // Pairs with equal cubes mod p, minus pairs with equal values
            // (equal values are never a solution when p != 3).
            sort(a + 1, a + n + 1);
            ans = -calc();
            for (int i = 1; i <= n; i++)
                a[i] = quick_mul(quick_mul(a[i], a[i]), a[i]);
            sort(a + 1, a + n + 1);
            ans += calc();
        }
        printf("%lld\n", ans);
    }
    return 0;
}
Code(Java Edition)
import java.io.*;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.StringTokenizer;
public class Main {
public static void main(String[] args) throws IOException {
SolverHDU_6128 solver = new SolverHDU_6128();
solver.run();
}
}
/**
 * Counts, for each test case, the pairs i < j with 1/(a_i + a_j) == 1/a_i + 1/a_j (mod p).
 *
 * <p>For p == 3 only pairs of equal values 1 or 2 qualify. Otherwise the answer is the
 * number of pairs with equal cubes mod p minus the number of pairs with equal values.
 *
 * <p>Input: T, then for each test case n and p followed by n values. Tokens may be split
 * across lines arbitrarily. If the input ends early, processing stops quietly after the
 * last complete test case.
 */
class SolverHDU_6128 {
    /** Initial capacity of the 1-based value array; grown on demand for larger n. */
    private static final int N = 100000 + 10;
    private static final BigInteger THREE = BigInteger.valueOf(3);

    /** Values of the current test case, stored at indices 1..n. */
    private long[] a = new long[N];
    /** Number of values in the current test case. */
    private int n;

    private BufferedReader reader;
    /** Tokens remaining on the most recently read line (null before the first read). */
    private StringTokenizer tokens;

    /**
     * Counts index pairs (i, j), 1 <= i < j <= n, inside contiguous runs of equal values in
     * a[1..n]. When a[1..n] is sorted this equals the number of pairs with a[i] == a[j].
     *
     * @return the sum of k*(k-1)/2 over every run of length k
     */
    long calc() {
        long ret = 0;
        int i = 1;
        while (i <= n) {
            int j = i;
            while (j <= n && a[j] == a[i]) {
                j++;
            }
            long cnt = j - i;
            ret += cnt * (cnt - 1) / 2;
            i = j;
        }
        return ret;
    }

    /**
     * Reads all test cases from standard input and writes one answer per line to
     * standard output. Standard output is closed when done.
     *
     * @throws IOException if reading or writing fails
     */
    void run() throws IOException {
        reader = new BufferedReader(new InputStreamReader(System.in));
        tokens = null;
        try (BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out))) {
            String first = nextToken();
            if (first == null) {
                return; // empty input: nothing to answer
            }
            int testCases = Integer.parseInt(first);
            while (testCases-- > 0) {
                String nToken = nextToken();
                String pToken = nextToken();
                if (pToken == null) {
                    break; // input ended before this test case's header
                }
                n = Integer.parseInt(nToken);
                long p = Long.parseLong(pToken);
                if (!readValues()) {
                    break; // value list truncated: no complete test case to answer
                }
                long ans = (p == 3) ? countForThree() : countGeneral(p);
                bw.write(String.valueOf(ans));
                bw.newLine();
            }
        }
    }

    /**
     * Returns the next whitespace-separated token, reading further lines as needed.
     *
     * @return the token, or null once the input is exhausted
     */
    private String nextToken() throws IOException {
        while (tokens == null || !tokens.hasMoreTokens()) {
            String line = reader.readLine();
            if (line == null) {
                return null;
            }
            tokens = new StringTokenizer(line);
        }
        return tokens.nextToken();
    }

    /**
     * Reads n values into a[1..n], growing the array first if it is too small.
     *
     * @return false if the input ended before all n values were read
     */
    private boolean readValues() throws IOException {
        if (a.length < n + 1) {
            a = new long[n + 1];
        }
        for (int i = 1; i <= n; i++) {
            String token = nextToken();
            if (token == null) {
                return false;
            }
            a[i] = Long.parseLong(token);
        }
        return true;
    }

    /** Answer for p == 3: only pairs of equal values 1 or 2 satisfy the condition. */
    private long countForThree() {
        long ones = 0;
        long twos = 0;
        for (int i = 1; i <= n; i++) {
            if (a[i] == 1) {
                ones++;
            } else if (a[i] == 2) {
                twos++;
            }
        }
        return ones * (ones - 1) / 2 + twos * (twos - 1) / 2;
    }

    /**
     * Answer for p != 3: pairs with a_i^3 == a_j^3 (mod p) minus pairs with a_i == a_j.
     * Equal values are never a solution when p != 3. Reorders and overwrites a[1..n].
     */
    private long countGeneral(long p) {
        Arrays.sort(a, 1, n + 1);
        long ans = -calc();
        // BigInteger avoids overflowing long while cubing values modulo a large p.
        BigInteger mod = BigInteger.valueOf(p);
        for (int i = 1; i <= n; i++) {
            a[i] = BigInteger.valueOf(a[i]).modPow(THREE, mod).longValue();
        }
        Arrays.sort(a, 1, n + 1);
        return ans + calc();
    }
}