题意:
给定一个长度为n(n<=1e5)的序列,询问从中能选出多少个长度为3的等差子序列。
序列中的每个元素 (1<=a<=3e4)
思路:
因为子序列是有序的,所以并不能直接用FFT求的答案
想了想好像也没有更好的写法,只能暴力的分块一下。
每个块需要统计一下几种情况:设有k个块 , m = (n/k)
①:选中序列中的一个数在块里,那么可以假设选中的第一个或者第二个或者第三个在块里,统计另外两个的情况,发现如果枚举第二个在块内,然后让块前的数列和块后的数列进行FFT运算就能在 m + L *log(L) = (1e6 + m) 的时间内统计出结果
注:L为FFT的长度,根据题意应为(1<<16)
②:选中序列中的两个数在块里,那么需要分别统计 选中的前两个(第一个和第二个)和后两个(第二个和第三个)在块里时的结果,因为是有序的遍历每个块,对于每种情况可以O(1)的得到结果,那么就可以在O(m*m)的时间内统计出结果
③:选中序列中的三个数全部在块里,那么需要枚举前两个数(或者后两个数)的位置,统计块内枚举位置的后面(后两个数:前面)对应的数出现的个数,这个利用类似于双指针的方法也可以对各种情况在O(1)内得到结果,那么也可以在O(m*m)的时间内统计出结果
那么总的复杂为 O( k * ( 2 * m * m + 1e6 + m ) ) , k * m 约等于 n ,那么 可以计算出k的一个合适的值使得复杂度最小,这里我选的 (k=100)使得复杂度在O(1e8)内。。虽然只给了3s,但是OJ运算速度很给力。。1.5s过了。。
代码:
#include<bits/stdc++.h>
const double PI = acos(-1.0);
using namespace std;
struct comple{
double r,i;
comple (double rr=0,double ii=0){
r = rr; i = ii;
}comple operator+(const comple& a){
return comple(r+a.r,i+a.i);
}comple operator-(const comple& a){
return comple(r-a.r,i-a.i);
}comple operator*(const comple& a){
return comple(r*a.r-i*a.i,r*a.i+i*a.r);
}
};
void brc(comple *a,int l){
for(int i=1,j=l/2;i<l-1;i++){
if(i<j)swap(a[i],a[j]);
int k = l/2;
while(j>=k){
j-=k;
k>>=1;
}if(j<k)j+=k;
}
}
void fft(comple*y,int l,int on){
brc(y,l);
comple u,t;
for(int h=2;h<=l;h<<=1){
comple wn(cos(on*2*PI/h),sin(on*2*PI/h));
for(int j=0;j<l;j+=h){
comple w(1,0);
for(int k=j;k<j+h/2;k++){
u = y[k];
t = w*y[k+h/2];
y[k] = u+t;
y[k+h/2] = u-t;
w = w*wn;
}
}
}if(on<0){
for(int i=0;i<l;i++){
y[i].r /= l;
}
}
}
const int MAXN = (1<<16)+10;
int n,m,len;
int A[100005];
int L[105];
int R[105];
int las[30005];
int nex[30005];
int now[30005];
long long cnt[MAXN];
comple F[2][MAXN];
long long ans = 0;
void oper(int l,int r){
int x = 2*A[r]-A[l];
if(x>=1&&x<=30000)ans+=now[x]+nex[x];
x = 2*A[l]-A[r];
if(x>=1&&x<=30000)ans+=las[x];
}
int main()
{
//freopen("data.in","r",stdin);
scanf("%d",&n);
if(n<=300)m=1;
else m = 100;
len = ( n + m-1 ) / m;
for(int i=1;i<=n;i++){
scanf("%d",&A[i]);
nex[A[i]]++;
}
for(int l=1,r=min(len,n),i=1;i<=m;i++){
L[i] = l;R[i] = r;
l = r+1;r = min(n,r+len);
}
for(int pos = 1; pos <= m ; pos++ ){
int l = L[pos] , r = R[pos] , ll = (1<<16);
for(int i=l;i<=r;i++)nex[A[i]]--;
for(int i=0;i<=30000;i++){
now[i]=0;
F[0][i].r = las[i];
F[1][i].r = nex[i];
}fft(F[0],ll,1);fft(F[1],ll,1);
for(int i=0;i<ll;i++){
F[0][i] = F[0][i] * F[1][i];
}fft(F[0],ll,-1);
for(int i=0;i<ll;i++){
cnt[i] = (long long)(F[0][i].r+0.5);
F[0][i].i = F[1][i].i = F[0][i].r = F[1][i].r = 0 ;
}
do{
ans += cnt[A[l]*2];
while(r>l){
oper(l,r);
now[A[r]]++;
r--;
}l++;
if(l>R[pos])break;
now[A[++r]]--;
ans += cnt[A[l]*2];
while(++r<=R[pos]){
now[A[r]]--;
oper(l,r);
}r--;l++;
}while(l<=R[pos]);
for(int i=L[pos];i<=R[pos];i++){
las[A[i]]++;
}
}printf("%lld\n",ans);
return 0;
}