题目描述
传送门
题目大意:给定一个长度为N的数组A[],求有多少对i, j, k(1<=i
题解
对序列进行分块。然后分情况讨论。
(1)i,j,k在同一个块中,从左向右顺序枚举i,j,对于j后面的数字出现情况用cnt[i]数组动态维护,每次计算答案的时候加上
cnt[2∗a[j]−a[i]]
(2)i,j在同一个块中,k在另一个块中或者j,k在同一个块中,i在另一个块中,用上面类似的方式维护。
(3)i,j,k在三个不同的块中,枚举j所在的块,用两个数组动态维护左右两端数字的出现情况,然后用FFT优化,得到
c[x]
,其中
x=a[i]+a[k]
,然后枚举j所在块中的数字,每次加上
c[a[j]∗2]
块的大小在2000左右比较合适。
代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#define N 200003
#define LL long long
#define pi acos(-1)
using namespace std;
struct data{
double x,y;
data(double X=0,double Y=0) {
x=X,y=Y;
}
}f[N],g[N];
int n,n1,m,cnt[N],a[N],cntl[N],cntr[N],L,R[N],blocksize,l[N],r[N];
LL c[N];
data operator +(data a,data b){
return data(a.x+b.x,a.y+b.y);
}
data operator -(data a,data b){
return data(a.x-b.x,a.y-b.y);
}
data operator *(data a,data b){
return data(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
}
void FFT(data a[],int n,int opt)
{
for (int i=0;i<n;i++)
if (i>R[i]) swap(a[i],a[R[i]]);
for (int i=1;i<n;i<<=1) {
data wn=data(cos(pi/i),opt*sin(pi/i));
for (int p=i<<1,j=0;j<n;j+=p) {
data w=data(1,0);
for (int k=0;k<i;k++,w=w*wn){
data x=a[j+k],y=w*a[j+k+i];
a[j+k]=x+y; a[j+k+i]=x-y;
}
}
}
}
int main()
{
freopen("a.in","r",stdin);
freopen("my.out","w",stdout);
scanf("%d",&n);
int mx=0;
for (int i=1;i<=n;i++) scanf("%d",&a[i]),mx=max(mx,a[i]);
mx*=2;
for (n1=1;n1<=mx;n1<<=1) L++;
for (int i=0;i<=n1;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
LL ans=0;
blocksize=min(n,2000);
int t=(n-1)/blocksize+1;
for (int i=1;i<=t;i++) {
l[i]=(i-1)*blocksize+1;
r[i]=min(l[i]+blocksize-1,n);
}
for (int i=1;i<=n;i++) cnt[a[i]]++;
for (int k=1;k<=t;k++) {
for (int i=l[k];i<=r[k];i++) cnt[a[i]]--;
for (int i=l[k];i<=r[k];i++)
for (int j=l[k];j<i;j++) {
int v=a[i]-a[j];
v=a[i]+v;
if (v>=0)ans+=(LL)cnt[v];
}
}
for (int i=1;i<=n;i++) cnt[a[i]]++;
for (int k=t;k>=1;k--) {
for (int i=l[k];i<=r[k];i++) cnt[a[i]]--;
for (int i=l[k];i<=r[k];i++)
for (int j=l[k];j<i;j++) {
int v=a[i]-a[j];
v=a[j]-v;
if (v>=0) ans+=(LL)cnt[v];
}
}
for (int k=1;k<=t;k++) {
for (int i=l[k-1];i<=r[k-1];i++) cnt[a[i]]=0;
for (int i=l[k];i<=r[k];i++) cnt[a[i]]++;
for (int i=l[k];i<=r[k];i++) {
cnt[a[i]]--;
for (int j=l[k];j<i;j++) {
int v=a[i]-a[j];
v=a[i]+v;
if (v>=0) ans+=(LL)cnt[v];
}
}
}
if (t>=3) {
for (int i=l[1];i<=r[1];i++) cntl[a[i]]++;
for (int i=l[3];i<=n;i++) cntr[a[i]]++;
for (int k=2;k<=t-1;k++) {
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
for (int i=0;i<=mx;i++) f[i].x=cntl[i];
for (int i=0;i<=mx;i++) g[i].x=cntr[i];
FFT(f,n1,1); FFT(g,n1,1);
for (int i=0;i<n1;i++) f[i]=f[i]*g[i];
FFT(f,n1,-1);
for (int i=0;i<=mx;i++) c[i]=(LL)(f[i].x/n1+0.5);
for (int i=l[k];i<=r[k];i++) ans+=c[a[i]*2];
for (int i=l[k];i<=r[k];i++) cntl[a[i]]++;
for (int i=l[k+1];i<=r[k+1];i++) cntr[a[i]]--;
}
}
printf("%lld\n",ans);
}