刘汝佳的数据结构专场题目, 题目就是给定一个1~n的排列求删除某个数后剩下的逆序对的个数, 我
一开始的做法是树状数组套Treap后来T了, 又翻了白书, 白书上的解法是用到了一个静态的BST,因为这
题的特殊性, BST只需要有删除动能, 所以再用Treap就显得大才小用了, 而且用静态的常数也小了很多,
套Treap总的复杂度是(n + q) * log(n) * log(n) 而 静态BST的复杂度是q * log(n) * log(n)。。。
书上还提到了一种用基于分治的树状数组解法, 后来才知道其实就是cdq 分治, 最近在学cdq分治, 这
题挺适合练习的, 首先对于每次删除操作我们可以预处理出删除某个数最多会减少多少个逆序对(即左边比该
数字大的数的个数再加上右边比该数字小的数的个数), 但删除某个数可能使得在它之后进行的删除操作的减
少的逆序对个数减少(有点绕。。。), 考虑在什么情况下一个删除操作会对另一个产生影响不难得到一个三
维偏序关系, 所以可以以删除数字的数值大小进行排序, 然后再分治, 表达的不是很清楚, 详见代码。。。
解法1:
#include <stdio.h>
#include <string.h>
#include <vector>
#include <algorithm>
#include <iostream>
using namespace std;
const int N = 200005;
const int M = N * 20;
typedef long long LL;
#define DB cout << "Yes" << endl;
#define DB2 cout << "No" << endl;
inline int readint() {
char c = getchar();
while (!isdigit(c)) c = getchar();
int x = 0;
while (isdigit(c)) {
x = x * 10 + c - '0';
c = getchar();
}
return x;
}
int buf[10];
inline void writeint(int x) {
int p = 0;
if (x == 0) p++;
else while (x) {
buf[p++] = x % 10;
x /= 10;
}
for (int j = p - 1; j >= 0; j--)
putchar('0' + buf[j]);
}
int arr[N], pos[N], tmp[N];
LL res;
inline int lowbit(int t) {
return t & (-t);
}
struct BST {
int L[M], R[M], sz[M], V[M], T[N];
bool is[M];
int n, root[N];
int id;
LL key;
void init(int n) {
id = 1;
this->n = n;
int l, r;
for (int i = 1; i <= n; i++) {
l = i - lowbit(i) + 1, r = i;
pos[arr[i]] = i;
for (int j = l; j <= r; j++)
T[j] = arr[j];
sort(T + l, T + r + 1);
build(root[i], l, r);
}
}
inline void push_up(int rt) {
sz[rt] = sz[L[rt]] + sz[R[rt]] + is[rt];
}
void build(int& rt, int l, int r) {
if (l > r) return;
rt = id++;
is[rt] = 1;
int mid = l + r >> 1;
int v = T[mid];
V[rt] = v;
L[rt] = R[rt] = 0;
build(L[rt], l, mid - 1);
build(R[rt], mid + 1, r);
push_up(rt);
}
void remove(int rt, int v) {
if (rt == 0) return;
if (v == V[rt]) {
is[rt] = 0;
sz[rt]--;
}
else if (v < V[rt]) {
remove(L[rt], v);
push_up(rt);
}
else {
remove(R[rt], v);
push_up(rt);
}
}
int lower(int rt, int v) {
if (rt == 0 || !sz[rt]) return 0;
if (v == V[rt]) {
return sz[L[rt]];
}
else if (v < V[rt]) {
return lower(L[rt], v);
}
else {
int tmp = sz[rt] - sz[R[rt]];
return tmp + lower(R[rt], v);
}
}
int upper(int rt, int v) {
if (rt == 0) return 0;
if (v == V[rt])
return sz[R[rt]];
else if (v > V[rt])
return upper(R[rt], v);
else {
int tmp = sz[rt] - sz[L[rt]];
return tmp + upper(L[rt], v);
}
}
void gao(int v) {
int p = pos[v];
int t = p;
while (p <= n) {
remove(root[p], v);
p += lowbit(p);
}
printf("%lld\n", key);
p = t - 1;
LL cnt = 0;
while (p) {
cnt += upper(root[p], v);
cnt -= lower(root[p], v);
p -= lowbit(p);
}
p = n;
while (p) {
cnt += lower(root[p], v);
p -= lowbit(p);
}
key -= cnt;
}
}T;
void msort(int* A, int l, int r) {
if (l < r) {
int mid = l + r >> 1;
msort(A, l, mid);
msort(A, mid + 1, r);
int p = l, q = mid + 1, id = l;
while (p <= mid || q <= r) {
if (q > r || p <= mid && A[p] <= A[q]) tmp[id++] = A[p++];
else {
res += mid - p + 1;
tmp[id++] = A[q++];
}
}
for (int i = l; i <= r; i++)
arr[i] = tmp[i];
}
}
int main() {
int n, m, v;
while (~scanf("%d%d", &n, &m)) {
for (int i = 1; i <= n; i++)
arr[i] = readint();
T.init(n);
res = 0;
msort(arr, 1, n);
T.key = res;
for (int i = 0; i < m; i++) {
scanf("%d", &v);
T.gao(v);
}
}
return 0;
}
解法2:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <string>
#include <algorithm>
using namespace std;
const int N = 200005;
typedef long long LL;
int A[N], pos[N], ans[N], del[N], op[N], t1[N], to[N], key[N];
LL c[N];
int n;
void add(int p) {
while (p <= n) {
c[p] += 1;
p += p & -p;
}
}
void sub(int p) {
while (p <= n) {
c[p] -= 1;
p += p & -p;
}
}
int sum(int m) {
int res = 0;
while (m) {
res += c[m];
m -= m & -m;
}
return res;
}
void solve(int l, int r, bool flag) {
if (l >= r) return;
int mid = l + r >> 1;
int p = l, q = mid + 1;
for (int i = l; i <= r; i++) t1[op[i] <= mid ? p++ : q++] = op[i];
copy(t1 + l, t1 + r + 1, op + l);
/*int c = 1;
*/
int cur = l;
for (int i = mid + 1; i <= r; i++) {
int x = op[i];
if (!flag) {
while (cur <= mid && del[op[cur]] < del[x]) {
add(pos[del[op[cur]]]);
cur++;
}
ans[x] -= sum(n) - sum(pos[del[x]]);
//cout << sum(n) << ' ' << sum(pos[del[x]]) << endl;
}
else {
while (cur <= mid && del[op[cur]] > del[x]) {
add(pos[del[op[cur]]]);
cur++;
}
ans[x] -= sum(pos[del[x]]);
}
}
for (int i = l; i < cur; i++) sub(pos[del[op[i]]]);
solve(l, mid, flag), solve(mid + 1, r, flag);
}
bool cmp(const int& a, const int& b) {
return del[a] < del[b];
}
bool cmp2(const int& a, const int& b) {
return del[a] > del[b];
}
LL res;
void msort(int l, int r) {
if (l >= r) return;
int mid = l + r >> 1;
int p = l, q = mid + 1, id = l;
msort(l, mid);
msort(mid + 1, r);
while (p <= mid || q <= r) {
if (q > r || (p <= mid && A[p] < A[q])) t1[id++] = A[p++];
else {
res += mid - p + 1;
t1[id++] = A[q++];
}
}
copy(t1 + l, t1 + r + 1, A + l);
}
int main() {
int m;
while (~scanf("%d%d", &n, &m)) {
for (int i = 1; i <= n; i++)
scanf("%d", A + i), pos[A[i]] = i, op[i] = i;
for (int i = 1; i <= m; i++)
scanf("%d", del + i);
memset(c, 0, sizeof(c));
for (int i = 1; i <= n; i++) {
int tmp = sum(A[i]);
key[i] = sum(n) - tmp;
add(A[i]);
key[i] += A[i] - 1 - tmp;
}
memset(c, 0, sizeof(c));
for (int i = 1; i <= m; i++)
ans[i] = key[pos[del[i]]];
res = 0;
msort(1, n);
sort(op + 1, op + 1 + m, cmp);
solve(1, m, 0);
sort(op + 1, op + 1 + m, cmp2);
solve(1, m, 1);
for (int i = 1; i <= m; i++) {
printf("%lld\n", res);
res -= ans[i];
}
}
return 0;
}