CodeChef - COUNTARI Arithmetic Progressions (FFT+分块)

28 篇文章 0 订阅

题目链接:https://www.codechef.com/problems/COUNTARI

题目大意:给出一个长度为n的数组,要求从中选出三个数a[i],a[j],a[k],满足i<j<k,且a[j]-a[i]=a[k]-a[j]。问能选出多少个这样的三元组。

题目思路:由于取出的三元组为等差数列,如果没有i < j < k的限制,我们就可以直接枚举中间那个数 a[j] 的值,借助FFT求出有多少二元组的和为2*a[j]的数量即可。但现在加入了i < j < k之后,我们就无法直接用枚举中间数的方法了,因为顺序无法保证满足i < j < k。在参考了大佬的博客之后,学会了如何用分块来解决这个问题。

我们先将这n个数分为block块,每一块的数的个数为n/block个(也可能不足)。接着对可能为答案的三元组进行分类讨论。

第一类为a[i]处于前面的块内,a[j]和a[k]处于当前的块内;

第二类为a[i]、a[j]、a[k]都处于当前的块内;

第三类为a[i]和a[j]处于当前块内,a[k]处于后面的块内;

第四类为a[i]处于前面的块内,a[j]处于当前的块内,a[k]处于后面的块内。

我们用Front[i] 表示前面的块内值为 i 的数量,Now[i] 表示当前的块内值为 i 的数量,Behind[i] 表示后面的块内值为 i 的数量。

对于前三种情况,我们可以枚举其中的两个值,再求第三个值的数量,复杂度为O(n/block * n/block),对于第四种情况来说,我们就可以将前面的块的值和后面的块的值做一次FFT,再枚举中间值a[j]即可计算出答案,复杂度为O(len * log(len))。

总体的时间复杂度为O(block * (n/block * n/block + len * log(len))。(虽然不太懂这个复杂度,但真的很快。。。)

具体实现看代码:

#include <bits/stdc++.h>
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define lowbit(x) x&-x
#define pb push_back
#define MP make_pair
#define clr(a) memset(a,0,sizeof(a))
#define _INF(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define fuck(x) cout<<"["<<x<<"]"<<endl
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int>pii;
//head
const int MX = 1e5+5;

const double pi = acos(-1.0);
int len,mx;//开大4倍
struct Complex {
    double r,i;
    Complex(double r=0,double i=0):r(r),i(i) {};
    Complex operator+(const Complex &rhs) {return Complex(r + rhs.r,i + rhs.i);}
    Complex operator-(const Complex &rhs) {return Complex(r - rhs.r,i - rhs.i);}
    Complex operator*(const Complex &rhs) {return Complex(r*rhs.r - i*rhs.i,i*rhs.r + r*rhs.i);}
} va[(1<<16)],vb[(1<<16)];
void rader(Complex F[],int len) { //len = 2^M,reverse F[i] with  F[j] j为i二进制反转
    int j = len >> 1;
    for(int i = 1; i < len - 1; ++i) {
        if(i < j) swap(F[i],F[j]);  // reverse
        int k = len>>1;
        while(j>=k) {
            j -= k;
            k >>= 1;
        }
        if(j < k) j += k;
    }
}
void FFT(Complex F[],int len,int t) {
    rader(F,len);
    for(int h=2; h<=len; h<<=1) {
        Complex wn(cos(-t*2*pi/h),sin(-t*2*pi/h));
        for(int j=0; j<len; j+=h) {
            Complex E(1,0); //旋转因子
            for(int k=j; k<j+h/2; ++k) {
                Complex u = F[k];
                Complex v = E*F[k+h/2];
                F[k] = u+v;
                F[k+h/2] = u-v;
                E=E*wn;
            }
        }
    }
    if(t==-1)   //IDFT
        for(int i=0; i<len; ++i)
            F[i].r/=len;
}
void Conv(Complex a[],Complex b[],int len) { //求卷积
    FFT(a,len,1);
    FFT(b,len,1);
    for(int i=0; i<len; ++i) a[i] = a[i]*b[i];
    FFT(a,len,-1);
}

int n;
int a[MX];
int Front[MX],Now[MX],Behind[MX];

void solve(){
    scanf("%d",&n);
    mx = 1;
    for(int i = 1;i <= n;i++){
        scanf("%d",&a[i]);
        Behind[a[i]]++;
        mx = max(mx,a[i]);
    }
    len = 1;
    while(len <= mx) len<<=1;
    len<<=1;
    int block = min(n,45);
    int sz = n/block;
    if(n%block != 0) sz++;
    ll ans = 0;
    for(int pos = 1;pos <= block;pos++){
        int s = sz*(pos - 1) + 1, t = min(n,sz*pos);
        for(int i = s;i <= t;i++) Behind[a[i]]--;
        for(int i = s;i <= t;i++){
            for(int j = i + 1;j <= t;j++){
                int cnt = 2*a[i] - a[j];
                if(cnt>=1 && cnt <= mx){
                    ans += Now[cnt]; ans += Front[cnt];
                }

                cnt = 2*a[j] - a[i];
                if(cnt>=1 && cnt <= mx){
                    ans += Behind[cnt];
                }
            }
            Now[a[i]]++;
        }

        for(int i = 0;i < len;i++){
            if(i <= mx){
                va[i] = Complex(Front[i],0);
                vb[i] = Complex(Behind[i],0);
            }else{
                va[i] = Complex(0,0);
                vb[i] = Complex(0,0);
            }
        }
        Conv(va,vb,len);
        for(int i = s;i <= t;i++){
            int cnt = 2*a[i];
            ans += (ll)(va[cnt].r+0.5);
        }
        for(int i = s;i <= t;i++){
            Front[a[i]]++; Now[a[i]]--;
        }
    }
    printf("%lld\n",ans);
}

int main(){
    //FIN;
    solve();
    return 0;
}

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值