题目地址:
https://www.acwing.com/problem/content/257/
给定长度为 N N N的整数序列 A A A,下标为 1 ∼ N 1∼N 1∼N。现在要执行 M M M次操作,其中第 i i i次操作为给出三个整数 l i , r i , k i l_i,r_i,k_i li,ri,ki,求 A [ l i ] , A [ l i + 1 ] , … , A [ r i ] A[l_i],A[l_{i+1}],…,A[r_i] A[li],A[li+1],…,A[ri](即 A A A的下标区间 [ l i , r i ] [l_i,r_i] [li,ri])中第 k i k_i ki小的数是多少。
输入格式:
第一行包含两个整数
N
N
N和
M
M
M。第二行包含
N
N
N个整数,表示整数序列
A
A
A。接下来
M
M
M行,每行包含三个整数
l
i
,
r
i
,
k
i
l_i,r_i,k_i
li,ri,ki,用以描述第
i
i
i次操作。
输出格式:
对于每次操作输出一个结果,表示在该次操作中,第
k
k
k小的数的数值。每个结果占一行。
数据范围:
N
≤
1
0
5
,
M
≤
1
0
4
,
∣
A
[
i
]
∣
≤
1
0
9
N≤10^5,M≤10^4,|A[i]|≤10^9
N≤105,M≤104,∣A[i]∣≤109
思路是可持久化线段树,又称主席树。
先介绍可持久化线段树的结构。可持久化线段树和可持久化Trie的思想是一致的,也是动态开点,即未被修改的路径重用上一个版本的节点,否则开新的点。参考https://blog.csdn.net/qq_46105170/article/details/119029015。在主席树里我们主要考虑可以做当前版本单点修改和任一版本区间查询的可持久化线段树。如下图所示:
例如左边绿色部分我们已经建立了一个线段树,版本号是
0
0
0,现在我们要做单点修改,该单点修改要经过
1
,
3
,
6
,
12
1,3,6,12
1,3,6,12这
4
4
4个节点,那么主席树会将未被修改的路径的所有节点直接拿过来复用,而对做了修改的路径开新点。由于
1
,
3
,
6
,
12
1,3,6,12
1,3,6,12这
4
4
4个点会发生改变,那么要新开一个树根
1
′
1'
1′,其左孩子没有修改,则复用版本
0
0
0的,对于将要被修改的
3
3
3号点,开辟一个新点,然后走下去,接着
7
7
7号点不会被修改,则复用版本
0
0
0的
7
7
7号点,再开辟新的
6
′
6'
6′点,然后走下去,这样以此类推。我们发现,每次新的版本最多只会new出大概
log
n
\log n
logn这么多节点,是很省空间的。但是由于主席树需要不停地开新节点,所以用完全二叉树的方式来存就没有必要了,因为除了一开始建的树以外,之后的版本里节点的左右孩子的下标都是不固定的。所以我们采用存指针的方式(即存两个孩子在数组中的下标),而对于当前节点维护的区间范围,可以作为参数在调用函数的时候传进来。
那么本题应该怎么做呢。首先要查找的是某个下标区间里的第 k k k小数。回想一下平衡树里查找第 k k k小数的过程,如果每个节点维护子树节点数,并且维护当前key出现次数,那么就可以通过类似折半查找的做法把第 k k k小的数求出来。本题也类似,可以建立一个线段树,每个节点维护的是 A A A在该范围内的数的个数。例如线段树里维护 [ 0 , 2 ] [0,2] [0,2]的区间的节点,记录的就是 A A A中取值在 [ 0 , 2 ] [0,2] [0,2]有多少个数。一开始版本 0 0 0的线段树相当于在维护空数组,接着将 A [ i ] A[i] A[i]逐次插入,形成 N N N个版本。第 i i i个版本维护的就是 A [ 1 ∼ i ] A[1\sim i] A[1∼i]在各个区间里取值的数的个数。如果要查询在 A [ 1 : r ] A[1:r] A[1:r]内的第 k k k小的数,就可以查看第 r r r个版本的线段树,然后每次看一下左孩子维护的区间里有多少个数,如果有 c c c个,并且 k ≤ c k\le c k≤c,那么就说明第 k k k小的数在左半区间,则去左半区间找第 k k k小的数;否则说明第 k k k小的数在右半区间,则去右半区间找第 k − c k-c k−c小的数。这和平衡树里求第 k k k小的数的过程完全一样,也是在二分答案。但是现在是要查询 A [ l : r ] A[l:r] A[l:r]内第 k k k小。这可以利用前缀和思想,我们考虑第 r r r个版本和第 l − 1 l-1 l−1个版本,两个版本的线段树的差,比如说比较两个版本维护区间 [ a , b ] [a,b] [a,b]的节点里记录的 c c c值,分别叫 c r c_r cr和 c l − 1 c_{l-1} cl−1,那么 c r − c l − 1 c_r-c_{l-1} cr−cl−1其实就是 A [ 1 : r ] A[1:r] A[1:r]相比于 A [ 1 : l − 1 ] A[1:l-1] A[1:l−1]而言,在 [ a , b ] [a,b] [a,b]里的数字个数多了多少个,那其实就是 A [ l : r ] A[l:r] A[l:r]里有多少个数在 [ a , b ] [a,b] [a,b]里。有了这个信息,就可以二分答案来解决了。本题由于 A [ i ] A[i] A[i]的取值范围过大,按照这个取值范围建线段树太费空间,需要做离散化,即将 A A A映射到 0 ∼ N − 1 0\sim N-1 0∼N−1,然后用线段树维护 0 ∼ N − 1 0\sim N-1 0∼N−1这个区间即可。求完之后再映射回来即可。
接下来考虑主席树的几个基本操作怎么实现:
1、建树。建树只建出一个树的框架,并不真的去把数据
A
A
A都填进去,真正填
A
A
A的时候会每填一个数,就开一个新的版本。建树完毕之后可以视为这是版本
0
0
0。代码如下:
// 这里的l和r代表线段树维护的范围,返回的是建好的树的树根下标
int build(int l, int r) {
// 征用新节点。idx记录的是用到了哪个节点。和Trie一样,节点的空间会预先开辟好
int p = ++idx;
// 如果递归到叶子了就返回节点下标
if (l == r) return p;
int mid = l + (r - l >> 1);
// 递归建立左右子树,并接到当前节点上
tr[p].l = build(l, mid), tr[p].r = build(mid + 1, r);
// 返回当前节点的下标
return p;
}
2、插入。这里的插入等价于普通线段树的单点修改,例如说插入 x x x,对应的就是被维护数组的下标 x x x的地方增加 1 1 1(这里的 x x x是离散化后的),只不过主席树会在每次插入的时候新开一个版本。可以写成递归版本:
// p是上一个版本的树根下标,[l, r]是p维护的区间,x是要插入的数,返回的是新版本的树根
int insert(int p, int l, int r, int x) {
// 先建出新版本树根,复制原树根信息
int q = ++idx;
tr[q] = tr[p];
tr[q].cnt++;
// 走到叶子了就直接返回节点下标
if (l == r) {
return q;
}
int mid = l + (r - l >> 1);
// 否则看x在哪个半区间,递归建立新路径
if (x <= mid) tr[q].l = insert(tr[p].l, l, mid, x);
else tr[q].r = insert(tr[p].r, mid + 1, r, x);
return q;
}
也可以写成非递归版本:
int insert(int p, int l, int r, int x) {
// 建立出新版本树根,然后向下建立新路径节点
int root = ++idx;
int q = root;
while (l < r) {
// 先拷贝之前版本的节点,并计数加1
tr[q] = tr[p];
tr[q].cnt++;
int mid = l + (r - l >> 1);
// 如果x在左半区间,则新开辟左孩子,然后p和q同时下移,并收缩右端点
if (x <= mid) {
tr[q].l = ++idx;
q = tr[q].l, p = tr[p].l;
r = mid;
} else {
// 否则x在右半区间,则新开辟右孩子,然后p和q同时下移,并收缩左端点
tr[q].r = ++idx;
q = tr[q].r, p = tr[p].r;
l = mid + 1;
}
}
// 最后走到了叶子节点,叶子节点的左右孩子都是0,不需要动,只需计数加1
tr[q].cnt++;
return root;
}
递归版本可能更好理解一些。
3、在
A
[
l
:
r
]
A[l:r]
A[l:r]里查询第
k
k
k小。这需要在第
l
−
1
l-1
l−1版本和第
r
r
r版本同时向下折半查找,每次都计算左半区间的元素个数
c
c
c,然后和
k
k
k比较,如果左半区间元素个数大于等于
k
k
k,则说明答案在左半区间,去左半区间找第
k
k
k小;否则说明在右半区间。去右半区间找第
k
−
c
k-c
k−c小。代码如下:
int query(int q, int p, int l, int r, int k) {
if (l == r) return l;
int cnt = tr[tr[q].l].cnt - tr[tr[p].l].cnt;
int mid = l + (r - l >> 1);
if (k <= cnt) return query(tr[q].l, tr[p].l, l, mid, k);
else return query(tr[q].r, tr[p].r, mid + 1, r, k - cnt);
}
总代码如下:
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int N = 1e5 + 10;
int n, m;
int a[N];
vector<int> nums;
struct Node {
int l, r;
int cnt;
} tr[(N << 2) + N * 17];
int root[N], idx;
int find(int x) {
return lower_bound(nums.begin(), nums.end(), x) - nums.begin();
}
int build(int l, int r) {
int p = ++idx;
if (l == r) return p;
int mid = l + (r - l >> 1);
tr[p].l = build(l, mid), tr[p].r = build(mid + 1, r);
return p;
}
int insert(int p, int l, int r, int x) {
int q = ++idx;
tr[q] = tr[p];
tr[q].cnt++;
if (l == r) {
return q;
}
int mid = l + (r - l >> 1);
if (x <= mid) tr[q].l = insert(tr[p].l, l, mid, x);
else tr[q].r = insert(tr[p].r, mid + 1, r, x);
return q;
}
int query(int q, int p, int l, int r, int k) {
if (l == r) return l;
int cnt = tr[tr[q].l].cnt - tr[tr[p].l].cnt;
int mid = l + (r - l >> 1);
if (k <= cnt) return query(tr[q].l, tr[p].l, l, mid, k);
else return query(tr[q].r, tr[p].r, mid + 1, r, k - cnt);
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
nums.push_back(a[i]);
}
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
root[0] = build(0, nums.size() - 1);
for (int i = 1; i <= n; i++)
root[i] = insert(root[i - 1], 0, nums.size() - 1, find(a[i]));
while (m--) {
int l, r, k;
scanf("%d%d%d", &l, &r, &k);
printf("%d\n", nums[query(root[r], root[l - 1], 0, nums.size() - 1, k)]);
}
return 0;
}
时间复杂度 O ( n log n + m log n ) O(n\log n+m\log n) O(nlogn+mlogn),即每次插入和查询时间复杂度都是 O ( log n ) O(\log n) O(logn),空间 O ( n + n log n ) O(n+n\log n) O(n+nlogn),每次插入都要新开 O ( log n ) O(\log n) O(logn)个节点。