题目链接:https://codeforces.com/contest/1324/problem/D
题目描述
有两个长度为 n 的数组 A, B。问有多少对 (i,j) 满足 i < j 且 A[i]+A[j] > B[i]+B[j]。
数据范围
2 <= n <= 2*10^5,对于数组的每个元素 x 有 x ∈[1, 10^9]。
样例
n = 5
A = [4,8,2,6,2]
B = [4,5,4,1,3]
答案为 7, 分别为(1,2) (1,4) (2,3) (2,4) (2,5) (3,4) (4,5)
题目类型:排序,双指针,区间查询
解题思路
由公式
A
[
i
]
+
A
[
j
]
>
B
[
i
]
+
B
[
j
]
A[i]+A[j] \gt B[i]+B[j]
A[i]+A[j]>B[i]+B[j] 得
A
[
i
]
−
B
[
i
]
>
B
[
j
]
−
A
[
j
]
A[i]-B[i] \gt B[j]-A[j]
A[i]−B[i]>B[j]−A[j] 。
构造两个数组:
- D a b [ i ] = A [ i ] − B [ i ] D_{ab}[i] = A[i] - B[i] Dab[i]=A[i]−B[i]
- D b a [ i ] = B [ i ] − A [ i ] D_{ba}[i] = B[i] - A[i] Dba[i]=B[i]−A[i]
则该题答案等价于 满足条件
D
a
b
[
i
]
>
D
a
b
[
j
]
D_{ab}[i] \gt D_{ab}[j]
Dab[i]>Dab[j] 且
i
<
j
i \lt j
i<j 的
(
i
,
j
)
(i,j)
(i,j)的数量。
为了便于统计,可将
D
a
b
,
D
b
a
D_{ab},D_{ba}
Dab,Dba 排序,排序时需记录位置信息。
设数组
P
o
s
a
,
P
o
s
b
Pos_a,Pos_b
Posa,Posb分别记录
D
a
b
,
D
b
a
D_{ab},D_{ba}
Dab,Dba 的元素在排序前的位置。
设
a
n
w
anw
anw 为最终答案,初始为 0 。
使用双指针
p
,
q
p,q
p,q 遍历排序后的
D
a
b
,
D
b
a
D_{ab},D_{ba}
Dab,Dba,
p
p
p 每增加一,
q
q
q 应增加至满足
D
a
b
[
p
]
>
D
a
b
[
q
]
D_{ab}[p]\gt D_{ab}[q]
Dab[p]>Dab[q] 的最大值。更新完
p
,
q
p,q
p,q 后统计满足
P
o
s
b
[
j
]
>
P
o
s
a
[
p
]
,
j
∈
[
1
,
q
]
Pos_b[j] > Pos_a[p], j ∈ [1, q]
Posb[j]>Posa[p],j∈[1,q] 的
j
j
j 的数量并累加到
a
n
w
anw
anw 中。此处可借助线段树,树状数组等区间查询算法完成。
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200001;
int A[MAXN], B[MAXN], st[MAXN*4];
struct Diff {
int pos;
int diff;
bool operator < (const Diff &r) const {
return this->diff < r.diff;
}
} diffAB[MAXN], diffBA[MAXN];
void update(int *st, int root, int L, int R, int goal) {
st[root]++;
if(L == R) {
return;
}
int mid = (L+R)>>1;
if(goal <= mid) {
update(st, root<<1, L, mid, goal);
} else {
update(st, root<<1|1, mid+1, R, goal);
}
}
int query(int *st, int root, int L, int R, int range) {
if (range <= 0) {
return 0;
}
if (R == range) {
return st[root];
}
int mid = (L+R)>>1;
if (range <= mid) {
return query(st, root<<1, L, mid, range);
}
return st[root<<1] + query(st, root<<1|1, mid+1, R, range);
}
int main() {
cin.sync_with_stdio(false);
int n;
cin >> n;
for(int i = 1; i <= n; i++) {
cin >> A[i];
}
for(int i = 1; i <= n; i++) {
cin >> B[i];
}
for(int i = 1; i <= n; i++) {
diffAB[i].pos = i;
diffAB[i].diff = A[i] - B[i];
diffBA[i].pos = i;
diffBA[i].diff = B[i] - A[i];
}
sort(diffAB+1, diffAB+n+1);
sort(diffBA+1, diffBA+n+1);
int64_t anw = 0;
for(int i = 1, j = 1; i <= n; i++) {
while(j <= n && diffBA[j].diff < diffAB[i].diff) {
update(st, 1, 1, n, diffBA[j].pos);
j++;
}
anw += query(st, 1, 1, n, diffAB[i].pos-1);
}
cout << anw << endl;
return 0;
}
扫描图片关注 HelloNebula 获取更多有趣题目~