想不到怎么办,分块试一下
块内的可以直接DP出来,块与块之间的,化一下公式发现
ai+ak=2aj
,所以处理第i块,维护1~i-1块每个数出现次数,i+1~n块每个数出现次数,FFT卷积一下,枚举块内的每个值的两倍累加上答案
鉴于FFT常数大,块的大小定为1800左右比较合适
code:
#include<set>
#include<map>
#include<deque>
#include<queue>
#include<stack>
#include<cmath>
#include<ctime>
#include<bitset>
#include<string>
#include<vector>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<climits>
#include<complex>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn = 310000;
const int maxm = 310000;
const int maxnn = 10000;
const double pi=acos(-1);
struct E
{
double x,y;
E(){}
E(double _x,double _y){x=_x;y=_y;}
}w[maxm],c[maxm],b[maxm]; int M,lm; int id[maxm];
E operator +(E x,E y){return E(x.x+y.x,x.y+y.y);}
E operator -(E x,E y){return E(x.x-y.x,x.y-y.y);}
E operator *(E x,E y){return E(x.x*y.x-x.y*y.y,x.y*y.x+x.x*y.y);}
void FFT(E *s,int sig)
{
for(int i=0;i<M;i++) if(i<id[i]) swap(s[i],s[id[i]]);
for(int m=2;m<=M;m<<=1)
{
int t=m>>1,tt=M/m;
for(int i=0;i<t;i++)
{
E wn=sig==1?w[i*tt]:w[M-i*tt];
for(int j=i;j<M;j+=m)
{
E tx=s[j],ty=wn*s[j+t];
s[j]=tx+ty;
s[j+t]=tx-ty;
}
}
}
if(sig==-1)for(int i=0;i<M;i++) s[i].x/=M;
}
int a[maxn],bel[maxn],st[maxnn],n,N,m;
int g1[maxm],g2[maxm];
ll ret;
int main()
{
scanf("%d",&n); N=1800; m=0;
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]); m=max(m,a[i]);
bel[i]=(i-1)/N+1;
}
for(int i=1;i<=bel[n];i++) st[i]=(i-1)*N+1;
st[bel[n]+1]=n+1;
for(M=1,lm=0;M<=(m+m);M<<=1,lm++);
for(int i=0;i<M;i++)id[i]=(id[i>>1]>>1)|((i&1)<<lm-1);
for(int km=2;km<=M;km<<=1)
{
int t=km>>1,tt=M/km;
for(int i=0;i<t;i++)
{
w[i*tt]=E(cos(2*pi*i/km),sin(2*pi*i/km));
w[M-i*tt]=E(cos(2*pi*i/km),sin(-2*pi*i/km));
}
}
for(int i=0;i<=m;i++)g1[i]=g2[i]=0;
for(int i=1;i<=n;i++)g2[a[i]]++;
ret=0;
for(int i=1;i<=bel[n];i++)
{
for(int l=st[i];l<st[i+1];l++) g2[a[l]]--;
if(i>1&&i<bel[n])
{
for(int l=0;l<M;l++) c[l]=E(g1[l],0),b[l]=E(g2[l],0);
FFT(c,1); FFT(b,1);
for(int l=0;l<M;l++) c[l]=c[l]*b[l];
FFT(c,-1);
for(int l=st[i];l<st[i+1];l++)
ret+=(ll)(c[a[l]<<1].x+0.5);
}
for(int l=st[i];l<st[i+1];l++)
{
for(int r=l+1;r<st[i+1];r++)
{
if(a[r]>=a[l])
{
if(2*a[l]>=a[r]) ret+=g1[2*a[l]-a[r]];
if(2*a[r]-a[l]<=m) ret+=g2[2*a[r]-a[l]];
}
else
{
if(2*a[l]-a[r]<=m) ret+=g1[2*a[l]-a[r]];
if(2*a[r]-a[l]>=0) ret+=g2[2*a[r]-a[l]];
}
}
g1[a[l]]++;
}
}
printf("%lld\n",ret);
return 0;
}