题目链接: BZOJ - 3110
题目分析
这道题是一道树套树的典型题目,我们使用线段树套线段树,一层是区间线段树,一层是权值线段树。一般的思路是外层用区间线段树,内层用权值线段树,但是这样貌似会很难写。多数题解都使用了外层权值线段树,内层区间线段树,于是我就这样写了。每次插入会在 logn 棵线段树中一共建 log^2(n) 个结点,所以空间应该开到 O(nlog^2(n)) 。由于这道题查询的是区间第 k 大,所以我们存在线段树中的数值是输入数值的相反数(再加上 n 使其为正数),这样查第 k 小就可以了。在查询区间第 k 大值的时候,我们用类似二分的方法,一层一层地逼近答案。
写代码的时候出现的错误:在每一棵区间线段树中修改数值的时候,应该调用的是像 Insert(Lc[x], 1, n, l, r) 这样子,但我经常写成 Insert(x << 1, s, t, l, r) 之类的。注意!
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <algorithm>
using namespace std;
const int MaxN = 100000 + 5, MaxM = 100000 * 16 * 16 + 5;
int n, m, f, a, b, c, Index, Ans;
int Root[MaxN * 4], Lc[MaxM], Rc[MaxM], Sum[MaxM], Lazy[MaxM];
inline int gmin(int a, int b) {
return a < b ? a : b;
}
inline int gmax(int a, int b) {
return a > b ? a : b;
}
int Get(int x, int s, int t, int l, int r) {
if (l <= s && r >= t) return Sum[x];
int p = 0, q = 0, m = (s + t) >> 1;
if (l <= m) p = Get(Lc[x], s, m, l, r);
if (r >= m + 1) q = Get(Rc[x], m + 1, t, l, r);
return (p + q + Lazy[x] * (gmin(t, r) - gmax(s, l) + 1));
}
int GetKth(int l, int r, int k) {
int s = 1, t = n * 2, m, x = 1, Temp;
while (s != t) {
m = (s + t) >> 1;
if ((Temp = Get(Root[x << 1], 1, n, l, r)) >= k) {
t = m; x = x << 1;
}
else {
s = m + 1; x = x << 1 | 1; k -= Temp;
}
}
return s;
}
void Insert(int &x, int s, int t, int l, int r) {
if (x == 0) x = ++Index;
if (l <= s && r >= t) {
Sum[x] += t - s + 1;
++Lazy[x];
return;
}
int m = (s + t) >> 1;
if (l <= m) Insert(Lc[x], s, m, l, r);
if (r >= m + 1) Insert(Rc[x], m + 1, t, l, r);
Sum[x] = Sum[Lc[x]] + Sum[Rc[x]] + Lazy[x] * (t - s + 1);
}
void Add(int l, int r, int Num) {
int s = 1, t = n * 2, m, x = 1;
while (s != t) {
Insert(Root[x], 1, n, l, r);
m = (s + t) >> 1;
if (Num <= m) {
t = m;
x = x << 1;
}
else {
s = m + 1;
x = x << 1 | 1;
}
}
Insert(Root[x], 1, n, l, r);
}
int main()
{
scanf("%d%d", &n, &m);
Index = 0;
for (int i = 1; i <= m; ++i) {
scanf("%d%d%d%d", &f, &a, &b, &c);
if (f == 1) {
c = -c + n + 1;
Add(a, b, c);
}
else {
Ans = GetKth(a, b, c);
Ans = -Ans + n + 1;
printf("%d\n", Ans);
}
}
return 0;
}