题目 :Kattis - aplusb A+B Problem
Given N integers in the range [−50000,50000], how many ways are there to pick three integers ai, aj, ak, such that i, j, k are pairwise distinct and ai+aj=ak? Two ways are different if their ordered triples (i,j,k) of indices are different.
Input
The first line of input consists of a single integer N
(1≤N≤200000). The next line consists of N space-separated integers a1,a2,…,aN
Output
Output an integer representing the number of ways.
Sample Input 1
4
1 2 3 4
Sample Output 1
4
Sample Input 2
6
1 1 3 3 4 6
Sample Output 2
10
题意
给一组数,求能够满足 a i + a j = a k a_i+a_j=a_k ai+aj=ak的三元组 ( i , j , k ) (i,j,k) (i,j,k)有多少个,其中 i , j , k i,j,k i,j,k互不相同。
思路
-
统计 a [ i ] a[i] a[i]的出现的次数 c n t [ a [ i ] ] cnt[ a[i] ] cnt[a[i]]。
-
构造多项式
其中 x x x的指数为 a [ i ] a[i] a[i],这一项的系数为 c n t [ a [ i ] ] cnt[ a[i] ] cnt[a[i]]。 -
为什么要这样构造呢?
- 首先我们要计算 a [ i ] + a [ j ] a[i]+a[j] a[i]+a[j]的每一种可能值会有多少个。比如 a [ i ] + a [ j ] a[i]+a[j] a[i]+a[j]可能等于 x x x,而 a [ i ] + a [ j ] = x a[i]+a[j] = x a[i]+a[j]=x总共有t种可能, a [ k ] = x a[k] = x a[k]=x有w种可能。则满足 a [ i ] + a [ j ] = a [ k ] = x a[i]+a[j] = a[k] = x a[i]+a[j]=a[k]=x,就有 t ∗ w t*w t∗w种可能。
- 那么我们要计算满足 a i + a j = a k a_i+a_j=a_k ai+aj=ak的三元组 ( i , j , k ) (i,j,k) (i,j,k)有多少个。对于每一种 x x x,求 t ∗ w t*w t∗w的累加就是答案。
- a [ k ] = x a[k] = x a[k]=x的可能情况很好求,计算 c n t [ a [ i ] ] cnt[ a[i] ] cnt[a[i]]就可以了。所以问题的关键在于计算 a [ i ] + a [ j ] a[i]+a[j] a[i]+a[j]的每一种可能值会有多少个。
- a [ i ] + a [ j ] a[i]+a[j] a[i]+a[j]是计算和,而在多项式乘法中,就是A的每一项和B的每一项相乘,系数相乘,指数相加。那么联想到这里,我们可以将 a [ i ] a[i] a[i]当做指数,系数为 a [ i ] a[i] a[i]的数量(即数组中中出现了几次 a [ i ] a[i] a[i]这个值),求出A*A = C的多项式C的值,然后 a [ i ] + a [ j ] = y a[i]+a[j]=y a[i]+a[j]=y可能值就是多项式C的第 x y x^y xy项的系数。
-
处理数据!
题目中所给的数据范围是 [−50000,50000],有负数,在作指数时,不好计算,因为我们是以数组下标代表指数的,显然负数无法用下标代表,所以我们统一加上50000,把数变成正数,范围为[0, 100000]。
-
但是题目还有个要求: i , j , k i,j,k i,j,k互不相同。所以我们要把重复情况给减去。
-
有哪些重复情况呢?
-
1、 i = j i = j i=j
首先我们在计算 a [ i ] + a [ j ] = y a[i]+a[j] = y a[i]+a[j]=y的可能数 t t t时,应该减去 a [ i ] + a [ i ] a[i]+a[i] a[i]+a[i]这种情况。 -
2、 i = k i = k i=k 或 j = k j = k j=k
显然这个时候就出现了0,即 a i + 0 = a i 即 ( a j = 0 ) a_i+0=a_i即(a_j = 0) ai+0=ai即(aj=0), 0 + a j = a j 即 ( a i = 0 ) 0+a_j=a_j即(a_i = 0) 0+aj=aj即(ai=0)。
对于每一个 a k a_k ak,这两种情况都需要减去。显然这两种情况的总数量就是 2 ∗ z e r o 2*zero 2∗zero。但是如果 a k = 0 a_k= 0 ak=0,即 a i = a j = 0 a_i = a_j = 0 ai=aj=0,也即 a i + a j = a k = 0 + 0 = 0 a_i + a_j = a_k = 0 + 0 = 0 ai+aj=ak=0+0=0时,显然多减去了一些。
2 ∗ z e r o 2*zero 2∗zero 实际上包含的是 ( 0 1... z e r o ) + a k = a k (0_{1...zero}) + a_k = a_k (01...zero)+ak=ak 和 a k + ( 0 1... z e r o ) = a k a_k + (0_{1...zero}) = a_k ak+(01...zero)=ak ,如果 a k = 0 a_k= 0 ak=0,则 k ∈ ( 1... z e r o ) k \in(1...zero) k∈(1...zero)。也就是说上面两个式子中 ( 0 1... z e r o ) + a k = a k (0_{1...zero}) + a_k = a_k (01...zero)+ak=ak 包含了 0 k + 0 k = 0 k 0_k+ 0_k = 0_k 0k+0k=0k, a k + ( 0 1... z e r o ) = a k a_k + (0_{1...zero}) = a_k ak+(01...zero)=ak 也包含了 0 k + 0 k = 0 k 0_k+ 0_k = 0_k 0k+0k=0k。
但是在情况1里面,我们已经减去了 i = j i = j i=j的情况,即这两次 0 k + 0 k = 0 k 0_k+ 0_k = 0_k 0k+0k=0k都已经在情况1里面减去了,这里多减了,所以要加回来。
-
-
答案可能很大,所以要用long long
AC代码
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <cmath>
#define maxn (1<<19)
#define pi acos(-1)
using namespace std;
typedef long long LL;
struct Complex
{
double re, im;
Complex(double r = 0.0, double i = 0.0)
{
re = r, im = i;
}
void print()
{
printf("%lf %lf\n", re, im);
}
};
Complex operator +(const Complex&A, const Complex&B)
{
return Complex(A.re + B.re, A.im + B.im);
}
Complex operator -(const Complex&A, const Complex&B)
{
return Complex(A.re - B.re, A.im - B.im);
}
Complex operator *(const Complex&A, const Complex&B)
{
return Complex(A.re * B.re - A.im * B.im, A.re * B.im + A.im * B.re);
}
Complex a[maxn], inv[2][maxn];
int N, rev[maxn];
int n, zero, cnt[maxn], sa[maxn];
LL num[maxn];
void FFT(Complex*a, int f)
{
Complex x, y;
for(int i = 0; i < N; i++)
if(i < rev[i]) //不加这条if会交换两次(就是没交换)
swap(a[i], a[rev[i]]);
for(int i = 1; i < N; i <<= 1) //i是准备合并序列的长度的二分之一
for(int j = 0, t = N / (i << 1); j < N; j += i << 1) //i*2是准备合并序列的长度,j是合并到了哪一位(第某段的开头的坐标),t表示每一份单位根占单位圆的多少
for(int k = 0, l = 0; k < i; k++, l += t) //k是第某段内的第i位(只扫描前一半,后面一半可以同时求)
{
x = inv[f][l] * a[j + k + i]; //inv[f][l]表示第L份单位根
y = a[j + k];
a[j + k] = y + x;
a[j + k + i] = y - x;
}
if(f)
for(int i = 0; i < N; i++)
a[i].re /= N;
}
void Init()
{
int bit = 0;
while((1 << bit) < N)
bit++;
for(int i = 0; i < N; i++)//预处理逆反位置
{
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
for(int i = 0; i < N; i++)//预处理单位根
inv[0][i] = inv[1][i] = Complex(cos(2 * pi * i / N), sin(2 * pi * i / N)), inv[1][i].im = -inv[0][i].im;
}
void pre()
{
zero = 0;
memset(cnt, 0, sizeof(cnt));
memset(num, 0, sizeof(num));
scanf("%d", &n);
int maxx = -1;
for(int i = 0; i < n; i++)
{
scanf("%d", &sa[i]);
cnt[sa[i] + 50000]++;
maxx = max(maxx, sa[i] + 50000);
if(sa[i] == 0)
zero++;
}
for(N = 1; N < maxx; N <<= 1);
N <<= 1;
for(int i = 0; i < N; i++)
{
if(i <= maxx)
a[i].re = cnt[i];
else
a[i].re = 0;
a[i].im = 0;
}
}
void work()
{
Init();
FFT(a, 0);
for(int i = 0; i < N; i++)
a[i] = a[i] * a[i];
FFT(a, 1);
for(int i = 0; i < N; i++)
num[i] = (LL)(a[i].re + 0.5);
for(int i = 0; i < n; i++)
num[(sa[i] + 50000) * 2]--;
LL ans = 0;
for(int i = 0; i < n; i++)
{
ans += num[ sa[i] + 2 * 50000 ];
ans -= 2 * zero;
if(sa[i] == 0)
ans += 2;
}
printf("%lld\n", ans);
}
int main()
{
pre();
work();
return 0;
}