题目地址
给你一堆数,问你满足
ai+aj=ak
的
(i,j,k)
三元组的数量。
因为有负数,所以给每个数右移50000
然后几乎是一个裸的FFT,就这么提交,然后WA了。
之后想到了忘记判不符合的情况了,只有一个地方要考虑一下。就是 ai=0且j==k或者aj=0且i=k 容易想到这个次数就是0的个数 tot0∗2 ,当 ak=0 时 是 (tot0−1)∗2
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <cmath>
using namespace std;
typedef long long LL;
const int MAXN = 262144*2+1000;
const int INF = 0x3f3f3f3f;
const double pi = acos(-1.0);
struct cp
{
double x,y;
cp() {}
cp(double x,double y):x(x),y(y) {}
inline double real() { return x; }
inline cp operator * (const cp& r) const { return cp(x*r.x - y*r.y,x*r.y+y*r.x); }
inline cp operator - (const cp& r) const { return cp(x-r.x,y-r.y); }
inline cp operator + (const cp& r) const { return cp(x+r.x,y+r.y); }
};
cp a[MAXN],b[MAXN];
LL r[MAXN],res[MAXN],ax[MAXN],bx[MAXN];
void fft_init(int nm,int k)
{
for (int i=0;i<nm;i++) r[i] = (r[i>>1]>>1) | ((i &1) << (k-1));
}
void fft(cp ax[],int nm,int op)
{
for (int i=0;i<nm;i++) if (i<r[i]) swap(ax[i],ax[r[i]]);
for (int h=2,m=1;h<=nm;h<<=1,m<<=1)
{
cp wn = cp(cos(op*2*pi/h),sin(op*2*pi/h));
for (int i=0;i<nm;i+=h)
{
cp w(1,0);
for (int j=i;j<i+m;++j,w=w*wn)
{
cp t=w*ax[j+m];
ax[j+m] = ax[j]-t;
ax[j] = ax[j]+t;
}
}
}
if (op == -1) for (int i=0;i<nm;i++) ax[i].x /=nm;
}
void trans(LL ax[],LL bx[],int n,int m)
{
int nm=1,k=0;
while (nm < 2*n || nm < 2*m ) nm<<=1,++k;
for (int i=0;i<n;i++) a[i] = cp(ax[i],0);
for (int i=0;i<m;i++) b[i] = cp(bx[i],0);
for (int i=n;i<nm;i++) a[i] = cp(0,0);
for (int i=m;i<nm;i++) b[i] = cp(0,0);
fft_init(nm,k);
fft(a,nm,1);fft(b,nm,1);
for (int i=0;i<nm;i++) a[i] = a[i]*b[i];
fft(a,nm,-1);
nm = n+m-1;
for (int i=0;i<nm;i++) res[i] = (LL)(a[i].real()+0.5);
}
int n;
LL l[MAXN];
int main()
{
LL anum = 0,tot0=0;
scanf("%d",&n);
for (int i=1;i<=n;i++)
{
scanf("%lld",&l[i]);
if (l[i] == 0) tot0++;
l[i] += 50000;
ax[l[i]] ++;
anum = max(anum,l[i]);
}
trans(ax,ax,anum+1,anum+1);
for (int i=1;i<=n;i++) res[l[i]*2] --;
LL ans=0;
for (int i=1;i<=n;i++)
{
ans += res[l[i]+50000];
if (l[i] == 50000) ans -= (tot0-1)*2;
else ans -= tot0*2;
}
printf("%lld\n",ans);
return 0;
}