题目链接: 二逼平衡树
大致题意
读不懂就不要做这个题啦!!!
解题思路
思路一: 线段树套平衡树(既然题目都叫树套树了, 总得给个面子吧.)
首先我们观察五个操作, 我们发现他很像平衡树的题, 用平衡树维护序列有序即可. 但是由于每次查询是是对于区间的一次查询, 因此我们应当联想到采用线段树来维护询问区间, 然后再通过平衡树来解决问题. (这里稍微啰嗦几句, 一般的线段树套平衡树, 线段树的主要作用都是把查询的区间划分开. 因此线段树的内部是不维护信息的. 每次查询时, 我们通过线段树找到查询区间, 然后再找到该区间所对应的平衡树去求解即可.)
那么分析五个操作.
操作①: 相当于找有多少个数比 k k k小, 而对于询问区间内的每一棵平衡树而言, 答案是可加的. 因此只需要在线段树内找到询问区间, 然后把区间内每一棵平衡树的答案加和即可.
操作②: 我们发现这个操作不同于操作①的一点是, 操作②不具有可加性, 我们没法直接找到多棵平衡树排名为k的值, 除非把这些平衡树合并起来. 但是用脚想也知道, 这样一定会炸复杂度的.
但是我们惊奇的发现~~(指的通过看题解)~~, 其实操作②可以通过二分答案, 然后通过操作①来求解. 每次二分
m
i
d
mid
mid, 看看
m
i
d
mid
mid的排名和
k
k
k之间的关系.
本段话我感觉很重要, 能够帮助大家更好的理解二分.
这里注意一点: 二分的时候, 应当把二分条件假象为: 是否当前mid值的排名<=k, 然后求当前答案的最右边界. 因为可能树内会有重复的值, 如果认为二分的条件为 ≥ k \ge k ≥k时, 假设答案为x, 分别排名为 k − 1 k - 1 k−1和 k k k, 那么由于我们操作①求排名时, 求的是最靠前的这个排名, 这样会导致当二分到 x x x时, 会认为 x x x不符合要求, 最终会二分到 x + 1 x + 1 x+1, 而 x + 1 x + 1 x+1的排名是 k + 1 k + 1 k+1而不是 k k k.(也可以通过把二分条件改为 > k, 最终答案再减去1即可.)
操作③: 聊完了毁天灭地的操作②, 在看操作③就舒服多了, 这就是一个平衡树的单点修改. 方法: 找到要修改的节点, 删除它, 然后再插入新值即可.
操作④, ⑤: 这两个操作就是平衡树中找前驱和后继的基操. 如果找不到输出INF的话, 只需要再每棵树中插入两个哨兵节点, 使得其值分别为 ± I N F \pm INF ±INF.
树套树AC代码在这里➡️ 线段树套平衡树AC代码
这里稍微计算一下复杂度:
时间复杂度: 除操作②外, 都是先在线段树内找到平衡树, 然后询问平衡树. 因此时间复杂度为 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n), 操作②多一个二分, 为: O ( n l o g 2 n l o g ( 值 域 ) ) O(nlog^2nlog(值域)) O(nlog2nlog(值域))
空间复杂度: 节点总数为 N N N, 线段树节点为 S E G N = 4 ∗ N SEGN = 4 * N SEGN=4∗N. 因此平衡树的节点数为: S E G N ∗ 2 + N ∗ l o g N SEGN * 2 + N * logN SEGN∗2+N∗logN.
思路二: 整体二分 (树套树的题怎么能不写整体二分呢?)
这里所说的整体二分指的是: 二分值域. 我们还是对于五个操作进行分析.
操作①: 相当于找有多少个数比 k k k小, 对于当前 m i d mid mid, 如果比 k k k小, 则在答案里累加上 [ l , m i d ] [l, mid] [l,mid]区间的数字个数.
操作②: 这算是整体二分求区间第k小数的标准问法了, 每次看看 [ l , m i d ] [l, mid] [l,mid]区间的数字个数是否 ≥ k \ge k ≥k, 若满足则去左区间, 反之则减去左区间贡献后, 去右区间求解.
操作③: 整体二分的修改, 把之前的值减去, 把新值在加上.
操作④: 如果当前 m i d < k mid < k mid<k, 则我们用当前区间的最大值去更新答案, 然后再去右区间寻找是否存在更优解. 反之则去左区间寻找最优解.
操作⑤: 如果当前 m i d > k mid > k mid>k, 则我们用当前区间的最小值去更新答案, 然后再去左区间寻找是否存在更优解. 反之则去右区间寻找最优解.
这样分析之后, 我们发现我们需要维护 区间数字个数, 区间最小值, 区间最大值. 因此可以考虑用==线段树==来维护整体二分. 考虑到每次二分处理完左区间后, 需要复原. 因此可以考虑用一个懒标记, 表示该区间是否应当被初始化.
还有一些细节问题, 因为我们需要在树内同时维护 < m i d < mid <mid的最大值 与 > m i d > mid >mid的最小值, 而统计区间数字个数时, 是不统计大于mid的数值个数的. 因此我们的modify需要区分出是哪种情况的修改.
整体二分AC代码在这里➡️ 整体二分AC代码
对于整体二分, 这里只计算一下时间复杂度, 所有的操作都涉及到了树中的操作, 因此时间复杂度大概为 O ( n l o g n l o g ( 值 域 ) ) O(nlognlog(值域)) O(nlognlog(值域)), 但是因为用了vector的原因, 实际上常数会大很多. 因此在洛谷上如果不吸氧, 有概率被卡. 但是吸氧的话, 就快了非常多.
AC代码(大家可以选择点开思路里的AC代码, 代码是一样的)
/* 线段树套平衡树 */
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 1; i <= (n); ++i)
using namespace std;
typedef long long ll;
const int N = 5E4 + 10, M = N * 4 * 2 + N * 16, INF = INT_MAX;
int w[N];
/* 平衡树部分 */
struct node {
int s[2], p, v;
int size;
void init(int _p, int _v) {
s[0] = s[1] = 0;
p = _p, v = _v; size = 1;
}
}t[M];
queue<int> nodes; //可用节点下标
void pushup(int x) { t[x].size = t[t[x].s[0]].size + t[t[x].s[1]].size + 1; }
void rotate(int x) {
int y = t[x].p, z = t[y].p;
int k = t[y].s[1] == x;
t[z].s[t[z].s[1] == y] = x, t[x].p = z;
t[y].s[k] = t[x].s[k ^ 1], t[t[x].s[k ^ 1]].p = y;
t[x].s[k ^ 1] = y, t[y].p = x;
pushup(y), pushup(x);
}
void splay(int x, int k, int& root) { //因为有多棵树
while (t[x].p != k) {
int y = t[x].p, z = t[y].p;
if (z != k) (t[y].s[1] == x) == (t[z].s[1] == y) ? rotate(y) : rotate(x);
rotate(x);
}
if (!k) root = x;
}
void insert(int c, int& root) {
int x = root, p = 0;
while (x) p = x, x = t[x].s[c > t[x].v];
t[p].s[c > t[p].v] = x = nodes.front(); nodes.pop(); t[x].init(p, c);
splay(x, 0, root);
}
int findval(int c, int root) {
int x = root;
while (x) {
if (t[x].v == c) return x;
x = t[x].s[c > t[x].v];
}
assert(0); //找不到的情况, 因为题目中保证一定找得到, 因此若找不到则返回错误.
}
void modify(int val, int c, int& root) { //把val修改为c
int x = findval(val, root);
splay(x, 0, root);
int l = t[x].s[0], r = t[x].s[1];
while (t[l].s[1]) l = t[l].s[1];
while (t[r].s[0]) r = t[r].s[0];
splay(l, 0, root), splay(r, l, root);
nodes.push(t[r].s[0]); t[r].s[0] = 0;
pushup(r), pushup(l);
insert(c, root);
}
int getnum(int c, int root) {
int x = root, res = 0;
while (x) {
if (t[x].v < c) res += t[t[x].s[0]].size + 1, x = t[x].s[1];
else x = t[x].s[0];
}
return res - 1; //需要减去哨兵-INF
}
int getpre(int c, int root) {
int x = root, res = 0;
while (x) {
if (t[x].v < c) res = x, x = t[x].s[1];
else x = t[x].s[0];
}
return t[res].v;
}
int getnext(int c, int root) {
int x = root, res = 0;
while (x) {
if (t[x].v > c) res = x, x = t[x].s[0];
else x = t[x].s[1];
}
return t[res].v;
}
/* 线段树部分 */
int L[N << 2], R[N << 2], ROOT[N << 2];
void build(int l, int r, int x = 1) {
L[x] = l, R[x] = r;
insert(-INF, ROOT[x]), insert(INF, ROOT[x]);
for (int i = l; i <= r; ++i) insert(w[i], ROOT[x]);
if (l == r) return;
int mid = l + r >> 1;
build(l, mid, x << 1), build(mid + 1, r, x << 1 | 1);
}
void update(int a, int c, int x = 1) {
modify(w[a], c, ROOT[x]); //把w[a]删掉, 插入c.
if (L[x] == R[x]) return;
int mid = L[x] + R[x] >> 1;
update(a, c, x << 1 | (a > mid));
}
int query(int l, int r, int c, int x = 1) {
if (l <= L[x] && r >= R[x]) return getnum(c, ROOT[x]);
int mid = L[x] + R[x] >> 1;
int res = 0;
if (l <= mid) res += query(l, r, c, x << 1);
if (r > mid) res += query(l, r, c, x << 1 | 1);
return res; //也可以通过 return res + (x == 1); 来直接得到正确排名
}
int query_pre(int l, int r, int c, int x = 1) {
if (l <= L[x] && r >= R[x]) return getpre(c, ROOT[x]);
int mid = L[x] + R[x] >> 1;
int res = -INF;
if (l <= mid) res = max(res, query_pre(l, r, c, x << 1));
if (r > mid) res = max(res, query_pre(l, r, c, x << 1 | 1));
return res;
}
int query_next(int l, int r, int c, int x = 1) {
if (l <= L[x] && r >= R[x]) return getnext(c, ROOT[x]);
int mid = L[x] + R[x] >> 1;
int res = INF;
if (l <= mid) res = min(res, query_next(l, r, c, x << 1));
if (r > mid) res = min(res, query_next(l, r, c, x << 1 | 1));
return res;
}
int main()
{
rep(i, M - 5) nodes.push(i); //初始化
int n, m; cin >> n >> m;
rep(i, n) scanf("%d", &w[i]);
build(1, n);
while (m--) {
int tp; scanf("%d", &tp);
if (tp == 1) {
int l, r, c; scanf("%d %d %d", &l, &r, &c);
//注意, query为查询有多少个数字比k小, 最终排名应该+1.
printf("%d\n", query(l, r, c) + 1);
}
else if (tp == 2) {
int l, r, k; scanf("%d %d %d", &l, &r, &k);
int L = 0, R = 0x3f3f3f3f;
while (L < R) {
int mid = L + R + 1 >> 1;
if (query(l, r, mid) + 1 <= k) L = mid;
else R = mid - 1;
}
printf("%d\n", L);
}
else if (tp == 3) {
int a, c; scanf("%d %d", &a, &c);
update(a, c);
w[a] = c; //记录原序列的修改情况, 防止重复修改a位置, 导致平衡树删除错误的值.
}
else if (tp == 4) {
int l, r, c; scanf("%d %d %d", &l, &r, &c);
printf("%d\n", query_pre(l, r, c));
}
else {
int l, r, c; scanf("%d %d %d", &l, &r, &c);
printf("%d\n", query_next(l, r, c));
}
}
return 0;
}
/* 整体二分代码 */
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 1; i <= (n); ++i)
using namespace std;
typedef long long ll;
const int N = 5E4 + 10, INF = INT_MAX;
int w[N], res[N]; bool vis[N]; //vis记录是否为操作3, 因为操作3无输出.
struct node {
int l, r;
int cou, fmax, fmin;
bool flag;
}t[N << 2];
void pushdown(node& op, bool) {
op.cou = 0, op.fmax = -INF, op.fmin = INF;
op.flag = 1;
}
void pushdown(int x) {
if (!t[x].flag) return;
pushdown(t[x << 1], 1), pushdown(t[x << 1 | 1], 1);
t[x].flag = 0;
}
void pushup(node& p, node& l, node& r) {
p.cou = l.cou + r.cou;
p.fmin = min(l.fmin, r.fmin);
p.fmax = max(l.fmax, r.fmax);
}
void pushup(int x) { pushup(t[x], t[x << 1], t[x << 1 | 1]); }
void build(int l, int r, int x = 1) {
t[x] = { l, r, 0, -INF, INF, 0 };
if (l == r) return;
int mid = l + r >> 1;
build(l, mid, x << 1), build(mid + 1, r, x << 1 | 1);
}
void modify(int a, int c, int tp, int target, int x = 1) { //target为二分时的mid值
if (t[x].l == t[x].r) {
if (c <= target) {
t[x].cou += tp;
t[x].fmax = t[x].cou ? c : -INF;
}
else t[x].fmin = tp == 1 ? c : INF;
return;
}
pushdown(x);
int mid = t[x].l + t[x].r >> 1;
modify(a, c, tp, target, x << 1 | (a > mid));
pushup(x);
}
auto ask(int l, int r, int x = 1) {
if (l <= t[x].l && r >= t[x].r) return t[x];
pushdown(x);
int mid = t[x].l + t[x].r >> 1;
if (r <= mid) return ask(l, r, x << 1); //答案只在左区间
if (l > mid) return ask(l, r, x << 1 | 1); //答案只在右区间
node res = { NULL, NULL, NULL, -INF, INF }; //答案同时在左右区间, 需要区间合并
node L = ask(l, r, x << 1), R = ask(l, r, x << 1 | 1);
pushup(res, L, R);
return res;
}
struct operation {
int tp, l, r, k, id;
// tp a c f 操作3需要
}; vector<operation> area;
void fact(int l, int r, vector<operation>& q) {
if (q.empty()) return;
if (l == r) {
for (auto& op : q) {
if (op.tp == 1) res[op.id]++; //最后排名需要++
else if (op.tp == 2) res[op.id] = l;
}
return;
}
int mid = l + r >> 1;
vector<operation> ql, qr;
for (auto& op : q) {
if (op.tp == 1) {
if (op.k > mid) res[op.id] += ask(op.l, op.r).cou, qr.push_back(op);
else ql.push_back(op);
}
else if (op.tp == 2) {
int cou = ask(op.l, op.r).cou;
if (cou >= op.k) ql.push_back(op);
else op.k -= cou, qr.push_back(op);
}
else if (op.tp == 3) {
modify(op.l, op.r, op.k, mid);
op.r <= mid ? ql.push_back(op) : qr.push_back(op);
}
else if (op.tp == 4) {
if (mid < op.k) res[op.id] = max(res[op.id], ask(op.l, op.r).fmax), qr.push_back(op);
else ql.push_back(op);
}
else {
if (mid >= op.k) res[op.id] = min(res[op.id], ask(op.l, op.r).fmin), ql.push_back(op);
else qr.push_back(op);
}
}
pushdown(t[1], 1); //初始化
fact(l, mid, ql), fact(mid + 1, r, qr);
}
int main()
{
int n, m; cin >> n >> m;
build(1, n);
rep(i, n) {
scanf("%d", &w[i]);
area.push_back({ 3, i, w[i], 1, NULL });
}
rep(i, m) {
int tp; scanf("%d", &tp);
if (tp == 1) {
int l, r, k; scanf("%d %d %d", &l, &r, &k);
area.push_back({ 1, l, r, k, i });
}
else if (tp == 2) {
int l, r, k; scanf("%d %d %d", &l, &r, &k);
area.push_back({ 2, l, r, k, i });
}
else if (tp == 3) {
int a, c; scanf("%d %d", &a, &c);
area.push_back({ 3, a, w[a], -1, NULL });
w[a] = c;
area.push_back({ 3, a, w[a], 1, NULL });
vis[i] = 1;
}
else {
int l, r, k; scanf("%d %d %d", &l, &r, &k);
area.push_back({ tp, l, r, k, i });
res[i] = (tp == 4) ? -INF : INF;
}
}
fact(0, 1E8 + 10, area);
rep(i, m) if (!vis[i]) printf("%d\n", res[i]);
return 0;
}