学习的cxlove大神的博客:传送门
关键点:枚举中间的数,如果其他两个数不在当前块中,那么前面的所有的块和后面所有的块做卷积,得到前面序列和后面序列相乘的情况。如果有这两个数有数在当前块,那么枚举前一个数在当前块,后一个数不在当前块,枚举后一个数在当前块,前一个数在当前块或者不在当前块,都可以。
分块的关键:是复杂度降到 o(k∗(Nk∗Nk+M∗logM) 这里不能按照 sqrt(N) 来分块,复杂度过大。
代码:
#include <bits/stdc++.h>
#define LL long long
#define FOR(i,x,y) for(int i = x;i < y;++ i)
#define IFOR(i,x,y) for(int i = x;i > y;-- i)
using namespace std;
//FFT copy from kuangbin
const double pi = acos (-1.0);
// Complex z = a + b * i
struct Complex {
double a, b;
Complex(double _a=0.0,double _b=0.0):a(_a),b(_b){}
Complex operator + (const Complex &rhs) const {
return Complex(a + rhs.a , b + rhs.b);
}
Complex operator - (const Complex &rhs) const {
return Complex(a - rhs.a , b - rhs.b);
}
Complex operator * (const Complex &rhs) const {
return Complex(a * rhs.a - b * rhs.b , a * rhs.b + b * rhs.a);
}
};
//len = 2 ^ k
void change (Complex y[] , int len) {
for (int i = 1 , j = len / 2 ; i < len -1 ; i ++) {
if (i < j) swap(y[i] , y[j]);
int k = len / 2;
while (j >= k) {
j -= k;
k /= 2;
}
if(j < k) j += k;
}
}
// FFT
// len = 2 ^ k
// on = 1 DFT on = -1 IDFT
void FFT (Complex y[], int len , int on) {
change (y , len);
for (int h = 2 ; h <= len ; h <<= 1) {
Complex wn(cos (-on * 2 * pi / h), sin (-on * 2 * pi / h));
for (int j = 0 ; j < len ; j += h) {
Complex w(1 , 0);
for (int k = j ; k < j + h / 2 ; k ++) {
Complex u = y[k];
Complex t = w * y [k + h / 2];
y[k] = u + t;
y[k + h / 2] = u - t;
w = w * wn;
}
}
}
if (on == -1) {
for (int i = 0 ; i < len ; i ++) {
y[i].a /= len;
}
}
}
const int maxn = 100010;
const int maxm = 30030;
int l[maxm<<2],r[maxm<<2];
LL num[maxm<<2];
int a[maxn],n,len;
void init(){
memset(l,0,sizeof(l));
memset(r,0,sizeof(r));
int mx_len = -1;
FOR(i,0,n){
scanf("%d",&a[i]);
++ r[a[i]];
mx_len = max(mx_len,a[i]);
}
mx_len = (mx_len+1) << 1;
len = 1;
while(len < mx_len) len <<= 1;
}
Complex x1[maxm<<2],x2[maxm<<2],x[maxm<<2];
void calc(){
FOR(i,0,len) x1[i] = Complex(l[i],0);//偷懒写法,这个时候l[maxm<<2],这里一直re,其实不应该偷懒的
FOR(i,0,len) x2[i] = Complex(r[i],0);
FFT(x1,len,1);
FFT(x2,len,1);
FOR(i,0,len) x[i] = x1[i]*x2[i];
FFT(x,len,-1);
FOR(i,0,len) num[i] = (LL)(x[i].a+0.5);
}
void work(){
LL ans = 0;
int blocks = min(n,30);
int blk = (n+blocks-1)/blocks;
FOR(i,0,blocks){
int s = i*blk,t = min(n,(i+1)*blk);
FOR(j,s,t) -- r[a[j]];
calc();
FOR(j,s,t){
int y = (a[j]<<1);
ans += num[y];
FOR(k,s,j) if(y >= a[k] && y-a[k] < maxm)
ans += r[y-a[k]];
FOR(k,j+1,t) if(y >= a[k] && y-a[k] < maxm)
ans += l[y-a[k]];
++ l[a[j]];
}
}
printf("%lld\n",ans);
}
int main(){
//freopen("test.in","r",stdin);
while(~scanf("%d",&n)){
init();
work();
}
}