Description
给定一个长度为
N
的数组
Input
第一行一个整数
N
。
接下来一行
Output
一行一个整数。
Sample Input
10
3 5 3 6 3 4 10 4 5 2
Sample Output
9
HINT
N≤105,Ai≤30000
题解
以下记
W=max{A1,…,AN}
。
这道题的
O(NW)
做法非常显然,即维护每个数左边和右边每种数字出现的次数。但这也是这道题的瓶颈所在。因为很显然这个算法是很难(或者不可能)继续优化下去的,所以很可能会卡在这里(如果你之前没做过类似的题目)。
这样,我们就考虑从一个看起来时间复杂度更坏的算法入手。
由于三个数构成等差数列,所以
2Aj=Ai+Ak
。我们可以对于每一个数维护左边和右边每种数字出现的次数,这个可以做到
O(N)
。然后统计方案数可以用卷积来实现,用 FFT 可以做到
O(Wlog2W)
,于是总复杂度为
O(NWlog2W)
。很明显是更差的。但是这个算法就少了很多局限性。
考虑我们卷积的过程,设多项式
f(x)
表示下标在区间
[L1,R1]
的生成函数(
xk
的系数表示数字
k
出现的次数);
这样就很明显可以分块来做。
我们把整个区间分成
K
块,枚举每一个数为中项,我们讨论下列三种情况:
1.首项和末项都在块内,可以用刚开始的做法,但是如果块内元素比较小,可以枚举首项下标,这样单块复杂度
2.首项和末项有一个在块内,我们可以枚举在块内的那一项,同样可以做到单块
O((NK)2)
。
3.首项和末项都不在块内,那么我们就需要用卷积了。一次卷积即可求出块内所有元素为中项的方案数。单块复杂度
O(Wlog2W)
。
那么总的复杂度就是
O(N2K+KWlog2W)
。
由均值不等式,
K=NWlogW√
时复杂度最低,为
O(NWlog2W−−−−−−−√)
。
但是事实上,由于常数等原因,块的大小需要调大大约 10 倍,约 2000 左右时最快。由于此题卡常严重,需要手写复数类。
My Code
/**************************************************************
Problem: 3509
User: infinityedge
Language: C++
Result: Accepted
Time:34664 ms
Memory:5784 kb
****************************************************************/
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <complex>
#define MAXN 65536
#define pi acos(-1)
using namespace std;
typedef long long ll;
struct E{
long double real, imag;
E(long double real = 0, long double imag = 0) : real(real), imag(imag) { }
inline friend E operator + (E &a, E &b)
{ return E(a.real + b.real, a.imag + b.imag); }
inline friend E operator - (E &a, E &b)
{ return E(a.real - b.real, a.imag - b.imag); }
inline friend E operator * (E &a, E &b)
{ return E(a.real * b.real - a.imag * b.imag , a.imag * b.real + a.real * b.imag); }
inline friend void swap(E &a, E &b)
{ E c = a; a = b; b = c; }
};
E a[MAXN + 1], b[MAXN + 1];
void bit_reverse(int n, E* r){
for(int i = 0, j = 0; i < n; i ++){
if(i > j) swap(r[i], r[j]);
for(int l = n >> 1; (j ^= l) < l; l >>= 1);
}
}
void fft(int n, E* r, int f){
bit_reverse(n, r);
for(int i = 2; i <= n; i <<= 1){
int m = i >> 1;
for(int j = 0; j < n; j += i){
E w(1, 0), wn(cos(2 * pi / i), f * sin(2 * pi / i));
for(int k = 0; k < m; k ++){
E z = r[j + m + k] * w;
r[j + m + k] = r[j + k] - z;
r[j + k] = r[j + k] + z;
w = w * wn;
}
}
}
if(f == -1){
E ww = E(1.0 / n, 0);
for(int i = 0; i < n; i ++) r[i] = r[i] * ww;
}
}
int n, k, m;
int d[100005], pos[100005], l[1005], r[1005];
ll ans;
int vis[30005];
int tmpl[MAXN], tmpr[MAXN];
void solve(int x){
for(int i = l[x]; i <= r[x]; i ++){
tmpr[d[i]]++;
}
for(int i = l[x]; i <= r[x]; i ++){
for(int j = i + 1; j <= r[x]; j ++){
int dk = d[i] + d[i] - d[j];
ans += tmpl[dk];
}
tmpl[d[i]]++;
}
for(int i = l[x]; i <= r[x]; i ++){
tmpr[d[i]] = tmpl[d[i]] = 0;
}
}
int N = 1;
void solsub(int x){
for(int i = l[x]; i <= r[x]; i ++){
for(int j = i + 1; j <= r[x]; j ++){
int dk = d[i] + d[i] - d[j];
if(dk >= 0) ans += tmpl[dk];
dk = d[j] + d[j] - d[i];
if(dk >= 0) ans += tmpr[dk];
}
}
if(x == 1 || x == m) return;
for(int i = 0; i <= N; i ++){
a[i] = b[i] = E(0, 0);
}
for(int i = 0; i <= N; i ++){
a[i] = E(tmpl[i], 0);
b[i] = E(tmpr[i], 0);
}
fft(N, a, 1); fft(N, b, 1);
for(int i = 0; i <= N; i ++){
a[i] = a[i] * b[i];
}
fft(N, a, -1);
for(int i = l[x]; i <= r[x]; i ++){
ans = ans + (ll)(a[2 * d[i]].real + 0.1);
}
}
void solve2(){
int mx = 0;
for(int i = 1; i <= n; i ++){
mx = max(d[i], mx);
}
mx = mx * 2 + 1;
while(N < mx) N = N << 1;
for(int i = 1; i <= n; i ++){
tmpr[d[i]] ++;
}
for(int i = 1; i <= m; i ++){
for(int j = l[i]; j <= r[i]; j ++){
tmpr[d[j]] --;
}
solsub(i);
for(int j = l[i]; j <= r[i]; j ++){
tmpl[d[j]] ++;
}
}
}
int main(){
scanf("%d", &n); k = 1823;
if(n < 1823) k = 1823;
for(int i = 1; i <= n; i ++){
scanf("%d", &d[i]);
}
for(int i = 1; i <= n; i ++){
pos[i] = (n - 1) / k + 1;
}
m = pos[n];
for(int i = 1; i <= m; i ++){
l[i] = (i - 1) * k + 1;
r[i] = i * k;
}
r[m] = n;
for(int i = 1; i <= m; i ++){
solve(i);
}
solve2();
printf("%lld\n", ans);
return 0;
}