Description
给定一个长度为N的数组A[],求有多少对i, j, k(1<=i
Input
第一行一个整数N(N<=10^5)。
接下来一行N个数A[i](A[i]<=30000)。
Output
一行一个整数。
Sample Input
10
3 5 3 6 3 4 10 4 5 2
Sample Output
9
HINT
进行化简可知题目要求即 a[k]+a[i]=2a[j],这个式子我们很容易想到可以暴力,但是显然这个复杂度是不够优秀的,但是我们注意到如果我们对于一个区间[l,r]考虑i,k均不在[l,r]中,j在[l,r]中的时候的答案可以发现其实是可以fft的,对于j,k在[l,r]中的时候可以用暴力来解决,这样子的话复杂度就可以被优化到N/B*VlogV+B^2。
CODE
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const double pi = acos(-1);
const int N = 100005;
const int M = 30005;
const int B = 3560;
int read()
{
int x = 0, f = 1;
char ch = getchar();
while (ch < '0' || ch > '9') {if (ch == '-') f = -1; ch = getchar();}
while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0'; ch = getchar();}
return x * f;
}
struct com
{
double x,y;
com operator + (const com &a) const {return (com){x + a.x, y + a.y};}
com operator - (const com &a) const {return (com){x - a.x, y - a.y};}
com operator * (const com &a) const {return (com){x * a.x - y * a.y, x * a.y + y * a.x};}
com operator / (const double &a) const {return (com){x / a, y / a};}
}a[M * 4], b[M * 4];
int L;
int rev[M * 4];
void fft(com *a, int f)
{
for (int i = 0; i < L; i++)
if (i < rev[i])
swap(a[i], a[rev[i]]);
for (int i = 1; i < L; i <<= 1)
{
com wn = (com){cos(pi / i), f * sin(pi / i)};
for (int j = 0; j < L; j += (i << 1))
{
com w = (com){1,0};
for (int k = 0; k < i; k++)
{
com u = a[j + k], v = a[j + k + i] * w;
a[j + k] = u + v, a[j + k + i] = u - v;
w = w * wn;
}
}
}
if (f == -1)
for (int i = 0; i < L; i++)
a[i] = a[i] / L;
}
int suf[M * 2], pre[M * 2];
int V;
int bel[N],sta[M],end[M];
int w[N];
int main()
{
int n = read();
for (int i = 1; i <= n; i++)
{
w[i] = read();
bel[i] = (i + B - 1) / B;
suf[w[i]]++;
V = std::max(V, w[i]);
if (!sta[bel[i]])
sta[bel[i]] = i;
end[bel[i]] = i;
}
int lg = 0;
for (L = 1; L <= V * 2; L <<= 1, lg++);
for (int i = 0; i < L; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
ll ans = 0;
for (int i = 1; i <= bel[n]; i++)
{
int l = sta[i], r = end[i];
for (int j = l; j <= r; j++)
suf[w[j]]--;
for (int j = 0; j < L; j++)
a[j] = b[j] = (com){0,0};
for (int j = 0; j <= V; j++)
a[j] = (com){pre[j], 0}, b[j] = (com){suf[j],0};
fft(a,1);
fft(b,1);
for (int j = 0; j < L; j++)
a[j] = a[j] * b[j];
fft(a,-1);
for (int j = l; j <= r; j++)
ans += (ll)(a[w[j] * 2].x + 0.1);
for (int j = l; j <= r; j++)
{
for (int k = j + 1; k <= r; k++)
ans += (w[j] * 2 - w[k] >= 0) ? pre[w[j] * 2 - w[k]] : 0;
pre[w[j]]++;
}
for (int j = l + 1; j <= r; j++)
for (int k = l; k < j; k++)
if (w[j] * 2 - w[k] >= 0)
ans += suf[w[j] * 2 - w[k]];
}
printf("%lld\n",ans);
}