BZOJ3509-FFT,分块

BZOJ3509

题目描述

给定一个长度为 N N N的数组 A [ ] A[] A[],求有多少对 i , j , k ( 1 < = i < j < k < = N i, j, k(1<=i< j< k<=N i,j,k1<=i<j<k<=N满足 A [ k ] − A [ j ] = A [ j ] − A [ i ] A[k]-A[j]=A[j]-A[i] A[k]A[j]=A[j]A[i]

题解

对题干式子移下项。
2 A [ j ] = A [ i ] + A [ k ] 2A[j]=A[i]+A[k] 2A[j]=A[i]+A[k]

考虑暴力。
枚举中间项,构造多项式 f k = ∑ i + j = k a i b j f_k=\sum\limits_{i+j=k}a_ib_j fk=i+j=kaibj对前半边和后半边做FFT,那么以当前位置为中间项的答案即为 f [ 2 ∗ A [ 当 前 项 ] ] f[2*A[当前项]] f[2A[]],时间复杂度即为 O ( n 2 l o g n ) O(n^2logn) O(n2logn)

考虑优化。
由于每次只求一项的值,所以考虑分块优化。
分两种情况讨论
1,当 i , k i,k i,k都不在 j j j所在的块时,FFT解决。
2,否则,暴力枚举统计答案。
显然当块的大小为 n l o g n = 2000 \sqrt{nlogn}=2000 nlogn =2000时,复杂度最优,为 O ( ( n l o g n ) 3 2 ) O((nlogn)^{\frac{3}{2}}) O((nlogn)23)

代码

O ( n 2 l o g n ) O(n^2logn) O(n2logn)

#include<bits/stdc++.h>
#define ll long long
#define int long long
#define M 100009
using namespace std;
int read(){
	int f=1,re=0;char ch;
	for(ch=getchar();!isdigit(ch)&&ch!='-';ch=getchar());
	if(ch=='-'){f=-1,ch=getchar();}
	for(;isdigit(ch);ch=getchar()) re=(re<<3)+(re<<1)+ch-'0';
	return re*f;
}
const int g=3;
const int mod=998244353; 
int r[M];
int ksm(int a,int b){//快速幂 
	int ans=1;
	while(b){
		if(b&1) ans=(ll)ans*a%mod;
		a=(ll)a*a%mod;
		b>>=1;
	}return ans%mod;
}
void ntt(int *A,int lim,int type){//ntt
	for(int i=0;i<lim;i++) if(i<r[i]) swap(A[i],A[r[i]]);
	for(int mid=1;mid<lim;mid<<=1){
		int W=ksm(g,(mod-1)/(mid<<1));
		for(int R=mid<<1,j=0;j<lim;j+=R){
			int w=1;
			for(ll k=0;k<mid;k++,w=(ll)w*W%mod){
				int x=A[j+k],y=(ll)w*A[j+k+mid]%mod;
				A[j+k]=(x+y)%mod;
				A[j+mid+k]=(x-y+mod)%mod;
			}
		}
	}
	if(type==-1){
		reverse(A+1,A+lim);
        int inv=ksm(lim,mod-2);
        for(int i=0;i<lim;i++) A[i]=(ll)A[i]*inv%mod;
	}
}
int a[M],b[M],c[M],d[M],n,maxn,ans;
signed main(){
	n=read();
	for(int i=1;i<=n;i++) a[i]=read(),maxn=max(maxn,a[i]);
	int lim=1,l=0;
	while(lim<maxn*2) lim<<=1,l++;
	for(int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	//printf("%d %d\n",maxn,lim);
	for(int i=2;i<n;i++){
		memset(b,0,sizeof(b));
		memset(c,0,sizeof(c));
		memset(d,0,sizeof(d));
		for(int j=1;j<i;j++) b[a[j]]++;
		for(int j=i+1;j<=n;j++) c[a[j]]++;
//		for(int j=0;j<lim;j++) printf("%lld ",b[j]);
//		printf("\n");
//		for(int j=0;j<lim;j++) printf("%lld ",c[j]);
//		printf("\n");
		ntt(b,lim,1),ntt(c,lim,1);
		for(int j=0;j<lim;j++) d[j]=(ll)b[j]*c[j];
		ntt(d,lim,-1);
		ans+=d[2*a[i]];
	}printf("%lld\n",ans);
	return 0;
}

正解

#include<bits/stdc++.h>
#define ll long long
#define int long long
#define M 100009
using namespace std;
int read(){
	int f=1,re=0;char ch;
	for(ch=getchar();!isdigit(ch)&&ch!='-';ch=getchar());
	if(ch=='-'){f=-1,ch=getchar();}
	for(;isdigit(ch);ch=getchar()) re=(re<<3)+(re<<1)+ch-'0';
	return re*f;
}
const int g=3;
const int mod=998244353; 
int rev[M];
int ksm(int a,int b){//快速幂 
	int ans=1;
	while(b){
		if(b&1) ans=(ll)ans*a%mod;
		a=(ll)a*a%mod;
		b>>=1;
	}return ans%mod;
}
void ntt(int *A,int lim,int type){//ntt
	for(int i=0;i<lim;i++) if(i<rev[i]) swap(A[i],A[rev[i]]);
	for(int mid=1;mid<lim;mid<<=1){
		int W=ksm(g,(mod-1)/(mid<<1));
		for(int R=mid<<1,j=0;j<lim;j+=R){
			int w=1;
			for(ll k=0;k<mid;k++,w=(ll)w*W%mod){
				int x=A[j+k],y=(ll)w*A[j+k+mid]%mod;
				A[j+k]=(x+y)%mod;
				A[j+mid+k]=(x-y+mod)%mod;
			}
		}
	}
	if(type==-1){
		reverse(A+1,A+lim);
        int inv=ksm(lim,mod-2);
        for(int i=0;i<lim;i++) A[i]=(ll)A[i]*inv%mod;
	}
}
int a[M],b[M],c[M],n,maxn,ans,num,L[M],R[M],l[M],r[M],block;
signed main(){
	n=read();block=2000;
	for(int i=1;i<=n;i++) a[i]=read(),maxn=max(maxn,a[i]),r[a[i]]++;
	
	int lim=1,lll=0;
	while(lim<maxn*2) lim<<=1,lll++;
	for(int i=0;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lll-1));
	
	num=n/block;if(n%block) num++;
    for(int i=1;i<=num;i++) L[i]=(i-1)*block+1,R[i]=i*block;R[num]=n;
    
	for(int i=1;i<=num;i++){
		for(int j=L[i];j<=R[i];j++) r[a[j]]--;
		for(int j=0;j<lim;j++) b[j]=c[j]=0;
		for(int j=0;j<=maxn;j++) b[j]=l[j],c[j]=r[j];
		ntt(b,lim,1),ntt(c,lim,1);
		for(int j=0;j<lim;j++) b[j]=(ll)b[j]*c[j];
		ntt(b,lim,-1);
		for(int j=L[i];j<=R[i];j++) ans+=(ll)b[2*a[j]];
		for(int j=L[i];j<=R[i];j++){
			for(int k=L[i];k<j;k++) if(2*a[j]-a[k]>0) ans+=(ll)r[2*a[j]-a[k]];
			for(int k=j+1;k<=R[i];k++) if(2*a[j]-a[k]>0) ans+=(ll)l[2*a[j]-a[k]];
			l[a[j]]++;
		}
	}printf("%lld\n",ans);
	return 0;
}
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值