线段树处理逆序对问题
逆序对问题可以由归并排序递归地处理,时间复杂度是 O(nlog2n) O ( n log 2 n ) 。但是在这里,使用线段树来加深理解。个人认为,线段树的方法和归并的方法根本区别在于,前者是一种在线算法,后者是一种离线算法(这只是个人的看法而已。。。)。在归并方法中,我们是在排序的过程中处理累计个数的,由于排序序列是已知的,我们是在已知序列的情况下进行统计,因此看成是离线的。而在线段树方法中,我们是在枚举的同时,立刻累加上逆序对信息,因此看成是在线的。
在最好理解的暴力方法中,真正耗时的步骤是查询过程,不断重复查找使得时间复杂度升高到 O(n2) O ( n 2 ) 。暴力查询的过程本身是耗时的,而在查询的过程中,我们只是把计算的焦点放在了当前查询的数据上,忽略掉了很多遍历过程中的信息。比如: 2– 3 4 5 1– 9 2 _ 3 4 5 1 _ 9 ,第一次从2开始搜索到1,找到了一个逆序对,该遍历过程肯定经过3,而且<3,1>也是逆序对,但是这个信息被完全忽略掉了。我们的优化点在于保留已经遍历过的信息(这有些像KMP算法的思想),同时减少遍历查询的时间。
首先,要摒弃传统的从给出的数列头遍历到数列尾查找的思想,而是先提取并保留有效信息。逆序对本身不关心数的绝对大小,只关心数的相对大小(可以理解成数在数轴上的相对位置)。那么,对数据离散化处理:把数列进行一次排序,之后去重处理,最后得到一个数据相对位置的数组pos[]
。 这个数组有效的长度是原来输入数列的长度,每个元素的值是第i
大的数在原来数列的位置。
之后,把这N个数据从头到尾进行枚举,累加计算线段树中第i+1
到第N
项的和。这是本算法中最难理解的部分,智商低,卡了一下午 :( 。 这里和暴力法正好有一个思维上的逆转,只要数据一出现,那么我们就立刻知道它的相对位置,因此只要知道了此时它的线段树后面位置上比它大的数,也就是在原始序列前的比它大的数,就立刻把总数加1。这也是在线
这一词的来历。
有一点需要注意:如果查询的数据值的本身波动范围不是太大,那么就没有必要进行离散化处理,处理过程的本身是相当耗时的。在这里是假设数据范围很大,以至于无法为线段树开辟这么大的内存空间。
代码:
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
const int MAXN = 35367; // 可能出现的最多能的数据个数
int times[MAXN] = {0}, MAX = 0; // 位置 次数 离散化后最大的位置
LL num[MAXN], b[MAXN]; // num是原始输入,b是num的副本
int pos[MAXN]; // 计算相对位置
struct Node {
int l, r, sum;
struct Node *lc, *rc;
Node(): l(0), r(0), sum(0), lc(nullptr), rc(nullptr) {}
};
void build(Node* &cur, int l, int r) { // 建树
cur = new Node;
cur->l = l;
cur->r = r;
if(l + 1 < r) {
build(cur->lc, l, (l + r) >> 1);
build(cur->rc, (l + r) >> 1, r);
}
}
void change(Node* cur, int x) { // 加入数据后更改信息
if(cur->l + 1 == cur->r) {
++cur->sum; // 数的个数增加一个
} else {
if(x < (cur->l + cur->r) / 2) {
change(cur->lc, x);
}
if(x >= (cur->l + cur->r) / 2) {
change(cur->rc, x);
}
cur->sum = cur->lc->sum + cur->rc->sum;
}
}
int query(Node* &cur, int l, int r) { // 查询算法
if(l <= cur->l && cur->r <= r) {
return cur->sum;
} else {
int ans = 0;
if(l < (cur->l + cur->r) / 2) {
ans += query(cur->lc, l, r);
}
if(r > (cur->l + cur->r) / 2) {
ans += query(cur->rc, l, r);
}
return ans;
}
}
int main() {
int N;
cin >> N;
for(int i = 0; i < N; ++i) {
cin >> num[i];
b[i] = num[i];
}
// 数据离散化处理
sort(num, num + N);
int len = distance(num, unique(num, num + N)); // 实际不重复元素的个数
for(int i = 0; i < N; ++i) {
pos[i] = lower_bound(num, num + len, b[i]) - num + 1; // 计算相对位置
}
// 线段树操作
Node* root{nullptr};
build(root, 1, len + 1); // 建树
int k = 0;
for(int i = 0; i < N; ++i) {
change(root, pos[i]);
if(pos[i] == len) { // 如果最大的数在最后,就没有比较的意义了。查询反而会使线段树结构出错
continue;
}
k += query(root, pos[i] + 1, len + 1); // 插入后,立刻进行累加,在线的
}
cout << k << endl;
return 0;
}
逆序对的拓展:
问题:在数列中只要有
ai<aj>ak
a
i
<
a
j
>
a
k
,且
i<j<k
i
<
j
<
k
,那么就称这是一个“好的”组合,给出任意个这个组合,求解“好的”组合的个数。
思路与逆序对一样,建树统计的代码也和逆序对的一样,区别在于统计方法上。
在代码中,每个位置的l[i]
和r[i]
的循环统计方向是正好相反的,这是由信息的时效性决定的。第i
的数据决定的“好的”序对个数m=l[i]*r[i]
(排列组合) 。很明显
ai<aj>ak
a
i
<
a
j
>
a
k
中,当前已知的是
aj
a
j
,我们需要统计
aj
a
j
两侧的数据。对于l[i]
来说, 需要知道i
左侧数据的情况,因此需要正向循环;r[i]
需要先知道i
右侧的情况,因此要反向循环。线段树的在线思想也在这里体现。
代码:
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
const int MAXN = 10000;
LL num[MAXN] = {0}, cp[MAXN] = {0}; // num存储数据、cp是num的副本
// pos存储位置信息、l[i]存储第i个元素左侧小于它的个数、r[i]存储第i个元素右侧大于它的元素
int pos[MAXN] = {0}, l[MAXN] = {0}, r[MAXN] = {0};
struct Node {
int l, r, sum;
struct Node *lc, *rc;
Node(): l(0), r(0), sum(0), lc(nullptr), rc(nullptr) {}
};
void build(Node* &cur, int l, int r) { // 建树
cur = new Node;
cur->l = l;
cur->r = r;
if(l + 1 < r) {
build(cur->lc, l, (l + r) >> 1);
build(cur->rc, (l + r) >> 1, r);
}
}
void change(Node* cur, int x) { // 增加个数
if(cur->l + 1 == cur->r) {
++cur->sum;
} else {
if(x < (cur->l + cur->r) >> 1) {
change(cur->lc, x);
}
if(x >= (cur->l + cur->r) >> 1) {
change(cur->rc, x);
}
cur->sum = cur->lc->sum + cur->rc->sum;
}
}
int query(Node* cur, int l, int r) { // 区间查询
if(l <= cur->l && cur->r <= r) {
return cur->sum;
} else {
int ans = 0;
if(l < (cur->l + cur->r) >> 1) {
ans += query(cur->lc, l, r);
}
if(r > (cur->l + cur->r) >> 1) {
ans += query(cur->rc, l, r);
}
return ans;
}
}
int main() {
int N;
cin >> N;
for(int i = 0; i < N; ++i) {
cin >> num[i];
cp[i] = num[i];
}
// 离散化处理
sort(num, num + N);
int len = distance(num, unique(num, num + N));
cout << "len=" << len << endl; // 不重复数据的个数
for(int i = 0; i < N; ++i) {
pos[i] = lower_bound(num, num + len, cp[i]) - num + 1; // 确定相对位置,从1开始
}
Node* root{nullptr};
build(root, 1, len + 1);
// 处理逆序的,注意是反向循环的!!!!!!!
for(int i = N - 1; i >= 0; --i) {
change(root, pos[i]);
if(pos[i] + 1 == len + 1) {
continue;
}
r[i] += query(root, pos[i] + 1, len + 1);
}
Node* root1{nullptr};
build(root1, 1, len + 1);
// 处理正序的
for(int i = 0; i < N; ++i) {
change(root1, pos[i]);
if(pos[i] <= 1) {
continue;
}
l[i] += query(root1, 1, pos[i]);
}
cout << "l[]:";
for(int i = 0; i < N; ++i) {
cout << l[i] << " ";
}
cout << endl << "r[]:";
for(int i = 0; i < N; ++i) {
cout << r[i] << " ";
}
cout << endl;
int k = 0;
for(int i = 0; i < N; ++i) { // 这里总共是N个数
k += l[pos[i]] * r[pos[i]];
}
cout << "res=" << k << endl;
return 0;
}