题目描述:
给定
n
n
个整数,求有多少个三元组
(i,j,k)
(
i
,
j
,
k
)
满足
1≤i<j<k≤n
1
≤
i
<
j
<
k
≤
n
且
aj−ai=ak−aj
a
j
−
a
i
=
a
k
−
a
j
。
1≤n≤100000,1≤ai≤30000
1
≤
n
≤
100000
,
1
≤
a
i
≤
30000
解题思路:
很容易想到
O(n2)
O
(
n
2
)
的做法:
l[t],r[t] l [ t ] , r [ t ] 分别为 [1,i),(i,n] [ 1 , i ) , ( i , n ] 中 t t 的个数。
好像不能用数据结构优化,但这是一个卷积形式,考虑分块。
枚举每一块,用FFT求出在该块前,
j
j
在该块中,在该块后的答案,再暴力统计其余情况即可。
时间复杂度为 O(n2Slogn+nS),S=nlogn−−−−−√ O ( n 2 S l o g n + n S ) , S = n l o g n 时最好,为 O(nnlogn−−−−−√) O ( n n l o g n )
#include<bits/stdc++.h>
#define ll long long
using namespace std;
int getint()
{
int i=0,f=1;char c;
for(c=getchar();(c!='-')&&(c<'0'||c>'9');c=getchar());
if(c=='-')c=getchar(),f=-1;
for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
return i*f;
}
const int N=100005,M=250000;
const double PI=acos(-1.0);
struct Complex
{
double x,y;
Complex(){}
Complex(double _x,double _y):x(_x),y(_y){}
Complex operator + (const Complex &b){return Complex(x+b.x,y+b.y);}
Complex operator - (const Complex &b){return Complex(x-b.x,y-b.y);}
Complex operator * (const Complex &b){return Complex(x*b.x-y*b.y,x*b.y+y*b.x);}
}f1[M],f2[M];
int n,m,mx,S,a[N],l[N],r[N],st[N],ed[N],pos[M];ll ans;
void FFT(Complex *f,int len,int on)
{
for(int i=1;i<len;i++)if(i<pos[i])swap(f[i],f[pos[i]]);
for(int i=1;i<len;i<<=1)
{
Complex wi=Complex(cos(PI/i),on*sin(PI/i));
for(int j=0;j<len;j+=(i<<1))
{
Complex wn=Complex(1,0);
for(int k=j;k<j+i;k++)
{
Complex u=f[k],v=wn*f[k+i];
f[k]=u+v,f[k+i]=u-v,wn=wn*wi;
}
}
}
if(on==-1)for(int i=0;i<len;i++)f[i].x/=len;
}
int main()
{
//freopen("lx.in","r",stdin);
n=getint();S=2000,m=(n-1)/S+1;
for(int i=1;i<=m;i++)st[i]=(i-1)*S+1,ed[i]=i*S;ed[m]=n;
for(int i=1;i<=n;i++)a[i]=getint(),r[a[i]]++,mx=max(mx,a[i]);
int len=1;while(len<=mx+mx)len<<=1;
for(int i=1;i<len;i++)pos[i]=(i&1)?pos[i>>1]>>1|(len>>1):pos[i>>1]>>1;
for(int i=1;i<=m;i++)
{
for(int j=st[i];j<=ed[i];j++)r[a[j]]--;
for(int j=0;j<len;j++)f1[j]=f2[j]=Complex(0,0);
for(int j=1;j<=mx;j++)f1[j]=Complex(l[j],0),f2[j]=Complex(r[j],0);
FFT(f1,len,1),FFT(f2,len,1);
for(int j=0;j<len;j++)f1[j]=f1[j]*f2[j];
FFT(f1,len,-1);
for(int j=st[i];j<=ed[i];j++)ans+=(ll)(f1[2*a[j]].x+0.5);
for(int j=st[i];j<=ed[i];j++)
{
for(int k=st[i];k<j;k++)
if(2*a[j]-a[k]>0)ans+=r[2*a[j]-a[k]];
for(int k=j+1;k<=ed[i];k++)
if(2*a[j]-a[k]>0)ans+=l[2*a[j]-a[k]];
l[a[j]]++;
}
}
cout<<ans<<'\n';
return 0;
}