题目链接:【CodeChef COUNTARI】Arithmetic Progressions
题目大意:给定一个长度为 n n 的数列,求数列中有多少个三元组 ,满足:
- 1≤i<j<k≤n 1 ≤ i < j < k ≤ n
- ai−aj=aj−ak a i − a j = a j − a k
n≤100000 n ≤ 100000 , m=max{ai}≤30000 m = m a x { a i } ≤ 30000 。
式子化为: 2aj=ai+ak 2 a j = a i + a k 。考虑枚举 j j 的位置,再将数列前半部分的生成函数与后半部分的生成函数做卷积,即可得到有多少对 ,使得 2aj=ai+ak 2 a j = a i + a k 。这样的时间复杂度是 Θ(nmlogm) Θ ( n m log m ) 的,会超时。考虑如何减少卷积次数。
考虑将数组分成 blocks b l o c k s 块, (i,j,k) ( i , j , k ) 的位置有三种情况,我们分别计算即可。
- 情况 1 1 : 在同一块内。在块内部枚举其中两个数的位置,就不难得出第三个数的值,用一个数组维护一下即可。时间复杂度 Θ(blocks⋅(nblocks)2)=Θ(n2blocks) Θ ( b l o c k s ⋅ ( n b l o c k s ) 2 ) = Θ ( n 2 b l o c k s ) 。
- 情况 2 2 : 有两个数在同一块内,而另一个数在另一块内。与刚才思路类似,枚举同一块内的两个数即可。时间复杂度 Θ(n2blocks) Θ ( n 2 b l o c k s ) 。
- 情况 3 3 : 所在的块两两不同。这样的话,类似一开始的思路,枚举 j j 的位置,再做卷积。不同的是,卷积次数降低到了 。
这时,取一个合适的 blocks b l o c k s 即可。
#include <cmath>
#include <cstdio>
#include <iostream>
using namespace std;
typedef long long ll;
const int maxn = 100005;
const int block = 2222;
const double pi = acos(-1.);
struct cd {
double r, i;
cd() {}
cd(double real, double imag) {
r = real, i = imag;
}
double& real() {
return r;
}
cd operator+(const cd &x) const{
return cd(r + x.r, i + x.i);
}
cd operator-(const cd &x) const{
return cd(r - x.r, i - x.i);
}
cd operator*(const cd &x) const{
return cd(r * x.r - i * x.i, r * x.i + i * x.r);
}
};
ll ans;
int n, m, bit, lim, a[maxn], r[maxn];
int cnt[maxn], pre[maxn], nxt[maxn];
cd f[maxn], g[maxn];
void fft(cd *a, int dft) {
for (int i = 0; i < lim; i++) {
if (i < r[i]) {
swap(a[i], a[r[i]]);
}
}
for (int k = 1; k < lim; k <<= 1) {
cd wn0(cos(pi / k), dft * sin(pi / k));
for (int i = 0; i < lim; i += k << 1) {
cd wnk(1, 0);
for (int j = i; j < i + k; j++, wnk = wnk * wn0) {
cd x = a[j], y = wnk * a[j + k];
a[j] = x + y, a[j + k] = x - y;
}
}
}
if (dft == -1) {
for (int i = 0; i < lim; i++) {
a[i].real() /= lim;
}
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", a + i);
m = max(m, a[i]);
}
for (lim = 1; lim <= m << 1; lim <<= 1) bit++;
for (int i = 0; i < lim; i++) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
for (int i = 1; i <= n; i++) {
nxt[a[i]]++;
}
for (int r, l = 1; l <= n; l += block) {
r = min(n, l + block - 1);
for (int i = l; i <= r; i++) {
nxt[a[i]]--;
}
// Type I & II
for (int i = l; i <= r; i++) {
for (int j = i + 1; j <= r; j++) {
int k = 2 * a[i] - a[j];
if (1 <= k && k <= m) {
ans += cnt[k] + pre[k];
}
k = 2 * a[j] - a[i];
if (1 <= k && k <= m) {
ans += nxt[k];
}
}
cnt[a[i]]++;
}
// Type III
for (int i = 0; i <= m; i++) {
f[i] = cd(pre[i], 0), g[i] = cd(nxt[i], 0);
}
for (int i = m + 1; i < lim; i++) {
f[i] = g[i] = cd(0, 0);
}
fft(f, 1), fft(g, 1);
for (int i = 0; i < lim; i++) {
f[i] = f[i] * g[i];
}
fft(f, -1);
for (int i = l; i <= r; i++) {
ans += ll(f[2 * a[i]].real() + 0.5);
pre[a[i]]++, cnt[a[i]]--;
}
}
printf("%lld\n", ans);
return 0;
}