BZOJ3509
题目描述
给定一个长度为 N N N的数组 A [ ] A[] A[],求有多少对 i , j , k ( 1 < = i < j < k < = N i, j, k(1<=i< j< k<=N i,j,k(1<=i<j<k<=N满足 A [ k ] − A [ j ] = A [ j ] − A [ i ] A[k]-A[j]=A[j]-A[i] A[k]−A[j]=A[j]−A[i]。
题解
对题干式子移下项。
2
A
[
j
]
=
A
[
i
]
+
A
[
k
]
2A[j]=A[i]+A[k]
2A[j]=A[i]+A[k]
考虑暴力。
枚举中间项,构造多项式
f
k
=
∑
i
+
j
=
k
a
i
b
j
f_k=\sum\limits_{i+j=k}a_ib_j
fk=i+j=k∑aibj对前半边和后半边做FFT,那么以当前位置为中间项的答案即为
f
[
2
∗
A
[
当
前
项
]
]
f[2*A[当前项]]
f[2∗A[当前项]],时间复杂度即为
O
(
n
2
l
o
g
n
)
O(n^2logn)
O(n2logn)
考虑优化。
由于每次只求一项的值,所以考虑分块优化。
分两种情况讨论
1,当
i
,
k
i,k
i,k都不在
j
j
j所在的块时,FFT解决。
2,否则,暴力枚举统计答案。
显然当块的大小为
n
l
o
g
n
=
2000
\sqrt{nlogn}=2000
nlogn=2000时,复杂度最优,为
O
(
(
n
l
o
g
n
)
3
2
)
O((nlogn)^{\frac{3}{2}})
O((nlogn)23)
代码
O ( n 2 l o g n ) O(n^2logn) O(n2logn)
#include<bits/stdc++.h>
#define ll long long
#define int long long
#define M 100009
using namespace std;
int read(){
int f=1,re=0;char ch;
for(ch=getchar();!isdigit(ch)&&ch!='-';ch=getchar());
if(ch=='-'){f=-1,ch=getchar();}
for(;isdigit(ch);ch=getchar()) re=(re<<3)+(re<<1)+ch-'0';
return re*f;
}
const int g=3;
const int mod=998244353;
int r[M];
int ksm(int a,int b){//快速幂
int ans=1;
while(b){
if(b&1) ans=(ll)ans*a%mod;
a=(ll)a*a%mod;
b>>=1;
}return ans%mod;
}
void ntt(int *A,int lim,int type){//ntt
for(int i=0;i<lim;i++) if(i<r[i]) swap(A[i],A[r[i]]);
for(int mid=1;mid<lim;mid<<=1){
int W=ksm(g,(mod-1)/(mid<<1));
for(int R=mid<<1,j=0;j<lim;j+=R){
int w=1;
for(ll k=0;k<mid;k++,w=(ll)w*W%mod){
int x=A[j+k],y=(ll)w*A[j+k+mid]%mod;
A[j+k]=(x+y)%mod;
A[j+mid+k]=(x-y+mod)%mod;
}
}
}
if(type==-1){
reverse(A+1,A+lim);
int inv=ksm(lim,mod-2);
for(int i=0;i<lim;i++) A[i]=(ll)A[i]*inv%mod;
}
}
int a[M],b[M],c[M],d[M],n,maxn,ans;
signed main(){
n=read();
for(int i=1;i<=n;i++) a[i]=read(),maxn=max(maxn,a[i]);
int lim=1,l=0;
while(lim<maxn*2) lim<<=1,l++;
for(int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
//printf("%d %d\n",maxn,lim);
for(int i=2;i<n;i++){
memset(b,0,sizeof(b));
memset(c,0,sizeof(c));
memset(d,0,sizeof(d));
for(int j=1;j<i;j++) b[a[j]]++;
for(int j=i+1;j<=n;j++) c[a[j]]++;
// for(int j=0;j<lim;j++) printf("%lld ",b[j]);
// printf("\n");
// for(int j=0;j<lim;j++) printf("%lld ",c[j]);
// printf("\n");
ntt(b,lim,1),ntt(c,lim,1);
for(int j=0;j<lim;j++) d[j]=(ll)b[j]*c[j];
ntt(d,lim,-1);
ans+=d[2*a[i]];
}printf("%lld\n",ans);
return 0;
}
正解
#include<bits/stdc++.h>
#define ll long long
#define int long long
#define M 100009
using namespace std;
int read(){
int f=1,re=0;char ch;
for(ch=getchar();!isdigit(ch)&&ch!='-';ch=getchar());
if(ch=='-'){f=-1,ch=getchar();}
for(;isdigit(ch);ch=getchar()) re=(re<<3)+(re<<1)+ch-'0';
return re*f;
}
const int g=3;
const int mod=998244353;
int rev[M];
int ksm(int a,int b){//快速幂
int ans=1;
while(b){
if(b&1) ans=(ll)ans*a%mod;
a=(ll)a*a%mod;
b>>=1;
}return ans%mod;
}
void ntt(int *A,int lim,int type){//ntt
for(int i=0;i<lim;i++) if(i<rev[i]) swap(A[i],A[rev[i]]);
for(int mid=1;mid<lim;mid<<=1){
int W=ksm(g,(mod-1)/(mid<<1));
for(int R=mid<<1,j=0;j<lim;j+=R){
int w=1;
for(ll k=0;k<mid;k++,w=(ll)w*W%mod){
int x=A[j+k],y=(ll)w*A[j+k+mid]%mod;
A[j+k]=(x+y)%mod;
A[j+mid+k]=(x-y+mod)%mod;
}
}
}
if(type==-1){
reverse(A+1,A+lim);
int inv=ksm(lim,mod-2);
for(int i=0;i<lim;i++) A[i]=(ll)A[i]*inv%mod;
}
}
int a[M],b[M],c[M],n,maxn,ans,num,L[M],R[M],l[M],r[M],block;
signed main(){
n=read();block=2000;
for(int i=1;i<=n;i++) a[i]=read(),maxn=max(maxn,a[i]),r[a[i]]++;
int lim=1,lll=0;
while(lim<maxn*2) lim<<=1,lll++;
for(int i=0;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lll-1));
num=n/block;if(n%block) num++;
for(int i=1;i<=num;i++) L[i]=(i-1)*block+1,R[i]=i*block;R[num]=n;
for(int i=1;i<=num;i++){
for(int j=L[i];j<=R[i];j++) r[a[j]]--;
for(int j=0;j<lim;j++) b[j]=c[j]=0;
for(int j=0;j<=maxn;j++) b[j]=l[j],c[j]=r[j];
ntt(b,lim,1),ntt(c,lim,1);
for(int j=0;j<lim;j++) b[j]=(ll)b[j]*c[j];
ntt(b,lim,-1);
for(int j=L[i];j<=R[i];j++) ans+=(ll)b[2*a[j]];
for(int j=L[i];j<=R[i];j++){
for(int k=L[i];k<j;k++) if(2*a[j]-a[k]>0) ans+=(ll)r[2*a[j]-a[k]];
for(int k=j+1;k<=R[i];k++) if(2*a[j]-a[k]>0) ans+=(ll)l[2*a[j]-a[k]];
l[a[j]]++;
}
}printf("%lld\n",ans);
return 0;
}