Problem Description
给出一个长度为 n n n 的排列 p p p ,进行 Q Q Q 次查询。
每次查询给出一个区间 L , R L,R L,R ( 1 ≤ L ≤ R ≤ n 1 \le L \le R \le n 1≤L≤R≤n),询问有多少对 ( i , j i,j i,j) ( L ≤ i < j ≤ R L\le i < j \le R L≤i<j≤R)满足 p i + p j p_{i} + p_{j} pi+pj 是平方数。
Input
第一行包含一个整数 T T T ( T ≤ 5 T \le 5 T≤5),表示测试组数。
对于每组测试,输入 Q + 3 Q+3 Q+3 行。
第一行包含一个整数 n ( 1 ≤ n ≤ 1 0 5 1 \le n \le 10^5 1≤n≤105),表示排列 p p p 的长度。
第二行包含 n n n 个整数 p 1 , p 2 , … , p n p_{1},p_{2},\ldots,p_{n} p1,p2,…,pn ( 1 ≤ p i ≤ n 1 \le p_{i} \le n 1≤pi≤n),表示排列 p p p 的各个元素。
第三行包含一个整数 Q Q Q ( 1 ≤ Q ≤ 1 0 5 1 \le Q \le 10^5 1≤Q≤105),表示询问次数。
接下来 Q Q Q 行每行包含两个整数 L L L 和 R R R ( 1 ≤ L ≤ R ≤ n 1 \le L \le R \le n 1≤L≤R≤n),表示查询区间。
Output
对于每次查询输出对应答案。
Solution
首先只考虑如何算出 [ 1 , n ] [1,n] [1,n] 的对数,
由于 p i p_i pi 是一个排列,所以 p i + p j p_i+p_j pi+pj 最大值为 2 n − 1 2n-1 2n−1 ,那么对于每个 p j p_j pj 我们可以 O ( n ) O(\sqrt{n}) O(n) 枚举出 1 ≤ i < j 1 \le i<j 1≤i<j 使 p i + p j p_i+p_j pi+pj 为平方数的个数。
for (int j = 2; j * j < 2 * n; j++) {
int x = j * j - a[i];
if (x >= 1 && x <= n && p[x] < i)sum[id[p[x]]]++, c[p[x]]++;
}
此时我们再加入多个区间查询,可以发现这就是一个二维数点的问题,那么我们就可以进行离线分块处理,将每次查询的 L , R L,R L,R 按 R R R 为关键词存下。
for (int i = 0; i < m; i++) {
int l, r;
cin >> l >> r;
q[r].emplace_back(l, i);
}
在 i ∈ [ 1 , n ] i \in [1,n] i∈[1,n] 顺序做的同时当存在 i = = R i==R i==R 时利用分块暴力跑 [ L , R ] [L,R] [L,R] ,最终时间复杂度在 O ( n n ) O(n \sqrt{n}) O(nn)。
for (int i = 1; i <= n; i++) {
for (int j = 2; j * j < 2 * n; j++) {
int x = j * j - a[i];
if (x >= 1 && x <= n && p[x] < i)sum[id[p[x]]]++, c[p[x]]++;
}
for (auto [l, j] : q[i]) {
for (int k = l; k <= min(n, id[l]*len); k++)ans[j] += c[k];
for (int k = id[l] + 1; k <= decn; k++)ans[j] += sum[k];
}
}
Code
#include <bits/stdc++.h>
#define endl '\n'
using namespace std;
void solve() {
int n;
cin >> n;
int decn = 0, len = sqrt(n);
vector<int>a(n + 1), p(n + 1);
vector<int>id(n + 1), c(n + 1), sum(n + 1);
vector<pair<int, int>>q[n + 1];
for (int i = 1; i <= n; i++)cin >> a[i], p[a[i]] = i;
for (int l = 1, r; l <= n; l = r + 1) {
++decn;
r = min(l + len - 1, n);
for (int i = l; i <= r; i++)id[i] = decn;
}
int m;
cin >> m;
vector<int>ans(m);
for (int i = 0; i < m; i++) {
int l, r;
cin >> l >> r;
q[r].emplace_back(l, i);
}
for (int i = 1; i <= n; i++) {
for (int j = 2; j * j < 2 * n; j++) {
int x = j * j - a[i];
if (x >= 1 && x <= n && p[x] < i)sum[id[p[x]]]++, c[p[x]]++;
}
for (auto [l, j] : q[i]) {
for (int k = l; k <= min(n, id[l]*len); k++)ans[j] += c[k];
for (int k = id[l] + 1; k <= decn; k++)ans[j] += sum[k];
}
}
for (auto x : ans)cout << x << endl;
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int t;
cin >> t;
while (t--)solve();
}