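/*
 * Problem: count the triples (i, j, k) with i < j < k and a[j] - a[i] = a[k] - a[j].
 * Rearranging gives 2*a[j] = a[i] + a[k].
 * So for every j we must count the pairs (i, k) with i to the left of j and k to its right.
 * Generating functions do this. Let the exponent be a value and the coefficient the
 * number of occurrences of that value. Build one polynomial for the elements left of j
 * and one for the elements right of j, multiply them with FFT, and read the
 * coefficient of x^(2*a[j]).
 * Doing this separately for every j costs O(n * m log m) with m = max(a[i]), which TLEs.
 * Use sqrt decomposition instead. For each block [l, r], a single FFT counts the
 * triples with i < l, l <= j <= r and k > r. The remaining triples (i or k in the
 * same block as j) are counted by brute force over pairs inside the block.
 * With block length B the total cost is O((n / B) * m log m + n * B).
 */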
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
// pi, written to more than full double precision; used for the FFT twiddle angles.
const double pi=3.14159265358979323846264338327950288;
// Length of every fixed-size global array in this file (values, counts, FFT buffers).
const int maxn=400000+10;
// Minimal complex number used by FFT(): r is the real part, i the imaginary part.
// Operators are const members taking const references, so they also work on
// const values and avoid a copy of the right-hand operand.
// Global buffers declared with the type:
//   A, B - transforms of the left / right value-count tables,
//   C    - their pointwise product, then its inverse transform,
//   D    - scratch space for FFT()'s bit-reversal reordering.
struct cp{
    double r,i;
    cp(double _r=0,double _i=0):r(_r),i(_i){}
    cp operator + (const cp &x) const {return cp(r+x.r,i+x.i);}
    cp operator - (const cp &x) const {return cp(r-x.r,i-x.i);}
    // (r + i*j)(x.r + x.i*j) = (r*x.r - i*x.i) + (r*x.i + i*x.r)*j
    cp operator * (const cp &x) const {return cp(r*x.r-i*x.i,r*x.i+i*x.r);}
}A[maxn],B[maxn],C[maxn],D[maxn];
// Input: n values stored in a[0..n-1].
int n;
int a[maxn];
// Value-count tables indexed by value: L[v] / R[v] count occurrences of v to the
// left / right of the position or block currently being processed by work().
int L[maxn];
int R[maxn];
// Block length of the sqrt decomposition (chosen in main()).
int blo;
// FFT length (a power of two) and its base-2 logarithm, set by remake().
int N;
int LL;
// rev[i]: LL-bit reversal of i, set by remake(). dig[]: per-bit scratch buffer.
int rev[maxn];
int dig[maxn];
// main() stores the maximum input value here; work() later reuses it for the
// maximum of the values outside the current block.
int mx=0;
// Running total of matching triples.
ll ans=0;
// In-place iterative radix-2 FFT over a[0..N-1] (N a power of two).
// flag = 1 : transform with the root e^{+2*pi*i/len} at each stage.
// flag = -1: transform with the conjugate root, then divide the real parts by N.
//            Imaginary parts are left unscaled; the inverse call in work() reads .r only.
// Precondition: rev[0..N-1] holds the bit-reversal permutation for the current N,
// as built by remake(). The global D[] is used as scratch for the reordering.
void FFT(cp a[],int flag){
    // Move every element to its bit-reversed slot via the scratch buffer.
    for (int pos=0;pos<N;++pos) D[pos]=a[rev[pos]];
    for (int pos=0;pos<N;++pos) a[pos]=D[pos];
    // Merge butterflies of width len = 2, 4, ..., N.
    for (int len=2;len<=N;len<<=1){
        int half=len>>1;
        double ang=2*pi/len;
        cp root(cos(ang),flag*sin(ang));
        for (int base=0;base<N;base+=len){
            cp w(1,0);
            cp *lo=a+base;
            cp *hi=lo+half;
            for (int t=0;t<half;++t){
                cp even=lo[t];
                cp odd=hi[t]*w;
                lo[t]=even+odd;
                hi[t]=even-odd;
                w=w*root;
            }
        }
    }
    if (flag==-1){
        for (int pos=0;pos<N;++pos) a[pos].r/=N;
    }
}
// Prepares an FFT for multiplying two polynomials of degree <= mx (the global).
// Sets N to a power of two with N > 2*mx (LL = log2 N), so the product's highest
// coefficient, at index 2*mx, neither wraps around into low indices nor falls
// outside C[0..N-1]. Fills rev[0..N-1] with the LL-bit reversal of each index,
// which FFT() uses for its reordering step.
// NOTE(review): N can reach about 4*mx; the fixed-size arrays (length maxn)
// must be able to hold N entries.
void remake(){
    // Smallest power of two strictly above mx, then doubled. Allowing N == mx here
    // would leave N == 2*mx for mx == 2^k, one slot short of the 2*mx+1 coefficients.
    for (N=1,LL=0;N<=mx;N<<=1,LL++);
    N<<=1; LL++;
    // Bit reversal by recurrence: reverse i>>1, shift it down one place, and put
    // i's lowest bit into the top position.
    rev[0]=0;
    for (int i=1;i<N;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(LL-1));
}
void work(){
for (int i=0;i<n;++i) ++R[a[i]];
for (int i=0;i<n;i+=blo){
int l=i,r=min(i+blo-1,n-1);
for (int j=l;j<=r;++j) --R[a[j]];
for (int j=l;j<=r;++j){
for (int k=j+1;k<=r;++k){
int tmp=2*a[j]-a[k];
if (tmp>=0) ans+=L[tmp];
tmp=2*a[k]-a[j];
if (tmp>=0) ans+=R[tmp];
}
L[a[j]]++;
}
}
for (int i=0;i<=mx;++i) L[i]=R[i]=0;
for (int i=0;i<n;i+=blo){
int l=i,r=min(n-1,i+blo-1);
mx=-1;
for (int j=0;j<l;++j) ++L[a[j]],mx=max(mx,a[j]);
for (int j=r+1;j<n;++j) ++R[a[j]],mx=max(mx,a[j]);
remake();
for (int j=0;j<N;++j) A[j]=cp(L[j],0.0);
for (int j=0;j<N;++j) B[j]=cp(R[j],0.0);
FFT(A,1); FFT(B,1);
for (int j=0;j<N;++j) C[j]=A[j]*B[j];
FFT(C,-1);
for (int j=l;j<=r;++j) ans+=(ll)(C[a[j]<<1].r+0.5);
for (int j=0;j<=mx;++j) L[j]=R[j]=0;
}
printf("%lld\n",ans);
}
// Reads n followed by n values and records their maximum in mx. Picks the block
// length (6 * floor(sqrt(n)), capped at n) and runs the counting in work().
int main(){
    scanf("%d",&n);
    for (int idx=0;idx<n;++idx){
        scanf("%d",a+idx);
        mx=std::max(mx,a[idx]);
    }
    blo=std::min(n,6*static_cast<int>(std::sqrt(n)));
    work();
    return 0;
}