题意简述
给出 $n$ 个元素的数列 $a$,问有多少三元组 $(i,j,k)$ 满足 $i<j<k$ 且 $a_j-a_i=a_k-a_j$。
数据范围
$1\le n\le 10^5$
$1\le a_i\le 3\times 10^4$
思路
分块+FFT。
对于 $i,j,k$ 中至少有两个在同一块内的三元组,我们可以通过维护一个计数数组来实现统计,这一部分是 $O(nB)$ 的($B$ 为块大小)。
对于 $i,j,k$ 都不在同一块内的三元组,我们枚举块,把当前块左边的和右边的元素分别按值计数后做 FFT 卷积,再扫描当前块统计答案,这一部分是 $O\left(\frac{n}{B}\cdot V\log V\right)$ 的,$V$ 是值域。
当 $B=\sqrt{V\log V}$ 时两部分的复杂度平衡,但是还要考虑一下FFT的常数QwQ……
代码
#include<algorithm>
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
// Circle constant used to build the FFT twiddle factors.
const double pi=std::acos(-1.0);
// Capacity of the FFT buffers and of the value-count table. It has to be at
// least as large as the transform length that main() selects (65536).
// A typed constant instead of a macro, so the name obeys normal scoping rules.
const int N=65600;
// Number of consecutive sequence positions that form one block.
// A typed constant instead of a macro: a macro called `block` would silently
// rewrite every later identifier with that name.
const int block=2000;
// Minimal complex number x + y*i offering only the arithmetic this file needs.
struct C{
    double x,y; // real part, imaginary part

    // Builds x + y*i. Both parts default to 0, so a plain number converts
    // implicitly into a purely real value (e.g. `C w=1`).
    C(double _x=0,double _y=0):x(_x),y(_y)
    {
    }

    // Assigns a real number. The imaginary part is cleared too, so the result
    // is exactly n1 + 0i instead of keeping a stale imaginary component.
    C& operator = (const double &n1)
    {
        x=n1;
        y=0;
        return *this;
    }

    // Complex product; does not modify either operand.
    C operator * (const C &n1) const
    {
        return C(x*n1.x-y*n1.y,x*n1.y+y*n1.x);
    }

    // Component-wise sum.
    C operator + (const C &n1) const
    {
        return C(x+n1.x,y+n1.y);
    }

    // Component-wise difference.
    C operator - (const C &n1) const
    {
        return C(x-n1.x,y-n1.y);
    }

    // Returns the real part.
    double real() const
    {
        return x;
    }

    // Divides both parts by an integer.
    C& operator /= (const int &n1)
    {
        x/=n1,y/=n1;
        return *this;
    }

    // In-place complex product. Both result parts are computed before either
    // member is written, so the result is also correct when n1 refers to
    // *this (e.g. `w*=w`).
    C& operator *= (const C &n1)
    {
        const double re=x*n1.x-y*n1.y;
        const double im=x*n1.y+y*n1.x;
        x=re;
        y=im;
        return *this;
    }

    // Adds a real number; only the real part changes.
    C& operator +=(const double &n1)
    {
        x+=n1;
        return *this;
    }
}A[N],B[N]; // complex work buffers of N entries each
// Input sequence; main() fills positions 0..n-1.
int seq[100010];
// Bit-reversal permutation table read by fft() and filled by main().
int r[N];
// Value-indexed occurrence counters used by main().
int t[N];
// Number of sequence elements read from the input.
int n;
// Transform length used by fft().
int lim;
// Shift base used by main() when filling r (bit i&1 is moved to position ti-1).
int ti;
// Running total of counted triples; printed at the end of main().
long long ans;
// In-place iterative radix-2 FFT over the first `lim` entries of A.
//   A : buffer to transform (overwritten with the result)
//   f : sign applied to the twiddle angle; the callers in this file pass 1
//       for the forward pass and -1 for the inverse pass
// No scaling is applied here: main() divides by lim after the inverse pass.
// Reads the globals lim (transform length), r (bit-reversal table) and pi.
void fft(C *A,int f)
{
    // Move every element to its bit-reversed position; the idx<rev test makes
    // sure each pair is exchanged only once.
    int idx=0;
    while (idx<lim)
    {
        const int rev=r[idx];
        if (idx<rev)
            swap(A[idx],A[rev]);
        ++idx;
    }

    // Butterfly passes: `half` is the half-length of the sub-transforms that
    // are being merged in the current pass.
    for (int half=1;half<lim;half*=2)
    {
        const int span=half*2;
        // Primitive root for this pass: angle pi/half, direction chosen by f.
        C step(cos(pi/half),f*sin(pi/half));
        for (int base=0;base<lim;base+=span)
        {
            C *lo=A+base;
            C *hi=lo+half;
            C w(1,0);
            for (int k=0;k<half;++k)
            {
                C even=lo[k];
                C odd=w*hi[k];
                lo[k]=even+odd;
                hi[k]=even-odd;
                w*=step; // advance to the next twiddle factor
            }
        }
    }
}
// Reads the next non-negative decimal integer from stdin into ret.
// Every character in front of the first digit is skipped. When the input ends
// before any digit is found, ret is left at 0 and the function returns
// instead of spinning forever on EOF.
void getint(int &ret)
{
    ret=0;
    // getchar() returns int; keeping it as int lets EOF be told apart from
    // every real character regardless of whether plain char is signed.
    int ch=getchar();
    while (ch!=EOF&&(ch<'0'||ch>'9'))
        ch=getchar();
    // EOF is outside the digit range, so this loop also stops at end of input.
    while (ch>='0'&&ch<='9')
    {
        ret=ret*10+(ch-'0');
        ch=getchar();
    }
}
int main()
{
getint(n);
for (int i=0;i<n;i++)
getint(seq[i]);
lim=65536,ti=16;
for (int i=0;i<lim;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<ti-1);
for (int i=1;block*(i+1)<n;i++)
{
memset(A,0,sizeof(A));
memset(B,0,sizeof(B));
for (int j=0;j<block*i;j++)
A[seq[j]]+=1;
for (int j=block*(i+1);j<n;j++)
B[seq[j]]+=1;
fft(A,1),fft(B,1);
for (int j=0;j<lim;j++)
A[j]*=B[j];
fft(A,-1);
for (int j=0;j<lim;j++)
A[j]/=lim;
for (int j=block*i;j<block*(i+1);j++)
ans+=(long long)(A[seq[j]*2].real()+0.5);
}
for (int i=0;block*i<n;i++)
{
memset(t,0,sizeof(t));
for (int j=block*i;j<n;j++)
t[seq[j]]++;
for (int j=block*i;j<min(block*(i+1),n);j++)
{
t[seq[j]]--;
for (int k=block*i;k<j;k++)
if (seq[j]+seq[j]-seq[k]>=0)
ans+=t[seq[j]+seq[j]-seq[k]];
}
memset(t,0,sizeof(t));
for (int j=0;j<block*i;j++)
t[seq[j]]++;
for (int j=min(block*(i+1),n)-1;j>=block*i;j--)
for (int k=min(block*(i+1),n)-1;k>j;k--)
if (seq[j]+seq[j]-seq[k]>=0)
ans+=t[seq[j]+seq[j]-seq[k]];
}
printf("%lld",ans);
return 0;
}