题目
题意:给定数组a和b,a相对位置不可动,b可以重新打乱顺序,插入到a数组中任意位置。求得到的新数组c,的最小逆序对。
官方题解
思路:要使逆序对最小,首先b数组必须为内部有序,所以需要对b数组进行预处理排序。插入到a数组的位置有n+1个。第i个位置表示插入在数
a
i
a_i
ai前,其中n+1表示插入在数
a
n
a_n
an后边。用
p
o
s
i
pos_i
posi表示
b
i
b_i
bi插入到
a
i
a_i
ai数组的新位置,由于b数组有序,所以有
p
o
s
1
<
=
p
o
s
2
<
=
,
,
,
<
=
p
o
s
m
pos_1<=pos_2<=,,,<=pos_m
pos1<=pos2<=,,,<=posm。由于这种有序性,我们可以递归二分,来限定每个
p
o
s
i
pos_i
posi可以插入的范围。
定义
p
o
s
i
pos_i
posi可以插入的范围为
[
p
l
,
p
r
]
[pl,pr]
[pl,pr],通过线性计数,求出
p
o
s
i
pos_i
posi的最佳插入位置(即最小逆序对数对应的位置)
p
m
pm
pm,那么有,对于
j
>
i
j>i
j>i的数
p
o
s
j
pos_j
posj,可以插入的范围为
[
p
m
,
p
r
]
[pm,pr]
[pm,pr];对于
j
<
i
j<i
j<i的数
p
o
s
j
pos_j
posj,可以插入的范围为
[
p
l
,
p
m
]
[pl,pm]
[pl,pm]。通过这种方式,我们可以逐步限定
p
o
s
j
pos_j
posj的插入取值,最终求出pos数组。
求出pos数组后,通过归并排序求出逆序对即可。
PS:注意逆序对的大小边界,以及 n + m n+m n+m的范围。
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2000010;
#define ll long long
int n, m;
int a[maxn], b[maxn];
int pos[maxn], c[maxn];
int lnum[maxn], rnum[maxn];
//[l, r) for b; [pl, pr] for inserte position
void dfs(int l, int r, int pl, int pr)
{
if (l >= r) return;
int m = (l + r) / 2;
int pm = pl, val = b[m];
lnum[pl] = rnum[pr] = 0;
for (int i = pl + 1; i <= pr; ++i) {
lnum[i] = lnum[i-1] + (a[i-1] > val);
}
for (int i = pr - 1; i >= pl; --i) {
rnum[i] = rnum[i + 1] + (a[i] < val);
}
for (int i = pl; i <= pr; ++i) {
if ((lnum[i] + rnum[i]) <
(lnum[pm] + rnum[pm])) {
pm = i;
}
}
pos[m] = pm;
dfs(l, m, pl, pm);
dfs(m + 1, r, pm, pr);
}
ll merge_sort(int l, int r)
{
if (l + 1 >= r) return 0;
int m = (l + r) / 2;
ll res = 0;
res += merge_sort(l, m);
res += merge_sort(m, r);
for (int i = l; i < m; ++i) {
a[i] = c[i];
}
for (int i = m; i < r; ++i) {
b[i] = c[i];
}
int i = l, j = m, k = l;
while (i < m && j < r) {
if (a[i] <= b[j]) {
c[k++] = a[i];
++i;
} else {
c[k++] = b[j];
res += m - i;
++j;
}
}
while (i < m) {
c[k++] = a[i];
++i;
}
while (j < r) {
c[k++] = b[j];
++j;
}
return res;
}
ll solve()
{
dfs(1, m + 1, 1, n + 1);
// get final array c[]
int k = 0, j = 1;
for (int i = 1; i <= n; ++i) {
while (j <= m && pos[j] <= i) {
c[k++] = b[j];
++j;
}
c[k++] = a[i];
}
while (j <= m) {
c[k++] = b[j];
++j;
}
// merge_sort get number of inversions
return merge_sort(0, k);
}
int main() {
int t;
scanf("%d", &t);
while (t--) {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i) {
scanf("%d", &a[i]);
}
for (int i = 1; i <= m; ++i) {
scanf("%d", &b[i]);
}
sort(b + 1, b + m + 1);
printf("%lld\n", solve());
}
}