主席树总结

之前一直以为主席树是个什么神仙玩意儿,后面看了下其实也不是很难。主席树也被称作可持久化线段树吧,这里的线段树一般是权值线段树,普通的权值线段树只能维护整个区间的权值信息,对于部分的区
间信息不能维护。而对于每个主席树的\(root[i]\),它存的是\(1\)~\(i\)的区间权值信息,相当于一个权值线段树的前缀和。因为某些问题区间信息具有可减性,所以可以用主席树来维护。这个之后有例题。
但对于每个前缀\(i\)都建立一颗权值线段树空间复杂度太高了,所以这里比较巧妙的一个地方就是前缀\(i\)与前缀\(i-1\)会有很多的重合部分,所以我们可以共用很多结点。那么每次插入前缀\(i\)时,共用的部分不>管,对于其余的结点我们新开一个结点就是了,所以每次最多就会开一条链。那么最后总的空间复杂度就很低了,为\(O(nlogn+nlogn)\)。时间复杂度也很低,为\(O(nlogn)\)

 

部分代码

一开始的时候,我们会建立一颗空树,之后会以此为基础来插入前缀。

void build(int &o, int l, int r) {
    o = ++T;
    if(l == r) {
        return ;
    }
    int mid = (l + r) >> 1;
    build(ls[o], l, mid) ;
    build(rs[o], mid + 1, r) ;
}

 
之后就是对于每个前缀\(i\)的插入了。

void update(int &o, int l, int r, int last, int p) {
    o = ++T;
    ls[o] = ls[last] ;
    rs[o] = rs[last] ;//共用结点
    if(l == r) {
        sum[o] = sum[last] + 1;
        return ;
    }
    int mid = (l + r) >> 1;
    if(p <= mid) update(ls[o], l, mid, ls[last], p) ;
    else update(rs[o], mid + 1, r, rs[last], p);
    pushup(o) ;
}
//main中
for(int i = 1; i <= n; i++) update(rt[i], 1, n, rt[i - 1], a[i]) ;

 
对于询问(以区间第k大为例):

int query(int L, int R, int l, int r, int k) {
    if(l == r) return l;
    int s = sum[ls[R]] - sum[ls[L]] ;
    int mid = (l + r) >> 1 ;
    if(s >= k) return query(ls[L], ls[R], l, mid, k) ;
    else return query(rs[L], rs[R], mid + 1, r, k - s) ;
}

 

例题

 


这是主席树基本用途之一,查询的时候通过比较区间权值数量来决定进入左子树还是右子树,具体见代码吧。
代码如下:

Code
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <cmath>
using namespace std;
typedef long long ll;
const int N = 1e5 + 5, M = 5005;
int n, m, T;
int a[N], b[N], c[N], d[N];
int rt[N] ;
int sum[N * 20], ls[N * 20], rs[N * 20];
void pushup(int o) {
    sum[o] = sum[ls[o]] + sum[rs[o]] ;
}
void build(int &o, int l, int r) {
    o = ++T;
    if(l == r) {
        return ;
    }
    int mid = (l + r) >> 1;
    build(ls[o], l, mid) ;
    build(rs[o], mid + 1, r) ;
}
void update(int &o, int l, int r, int last, int p) {
    o = ++T;
    ls[o] = ls[last] ;
    rs[o] = rs[last] ;
    if(l == r) {
        sum[o] = sum[last] + 1;
        return ;
    }
    int mid = (l + r) >> 1;
    if(p <= mid) update(ls[o], l, mid, ls[last], p) ;
    else update(rs[o], mid + 1, r, rs[last], p);
    pushup(o) ;
}
int query(int L, int R, int l, int r, int k) {
    if(l == r) return l;
    int s = sum[ls[R]] - sum[ls[L]] ;
    int mid = (l + r) >> 1 ;
    if(s >= k) return query(ls[L], ls[R], l, mid, k) ;
    else return query(rs[L], rs[R], mid + 1, r, k - s) ;
}
int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    cin >> n >> m;
    for(int i = 1; i <= n; i++) cin >> a[i], b[i] = c[i] = a[i];
    sort(b + 1, b + n + 1) ;
    int D = unique(b + 1, b + n + 1) - b - 1;
    for(int i = 1; i <= n; i++) a[i] = lower_bound(b + 1, b + D + 1, a[i]) - b, d[a[i]] = c[i];
    build(rt[0], 1, D) ;
    for(int i = 1; i <= n; i++) update(rt[i], 1, D, rt[i - 1], a[i]) ;
    for(int i = 1; i <= m; i++) {
        int l, r, k ;
        cin >> l >> r >> k ;
        int ans = query(rt[l - 1], rt[r], 1, D, k) ;
        cout << d[ans] << '\n';
    }
    return 0;
}

 


可以用这个加深一下对可持久化的理解吧。
代码如下:

Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e6 + 5 ;
int n, m;
int a[N];
int rt[N], tre[N * 20], lc[N * 20], rc[N * 20];
int T;
void build(int &o, int l, int r) {
    o = ++T;
    if(l == r) {
        tre[o] = a[l] ;
        return ;
    }
    int mid = (l + r) >> 1;
    build(lc[o], l, mid) ;
    build(rc[o], mid + 1, r) ;
}
void update(int &o, int l, int r, int last, int sign, int p, int v) {
    o = ++T;
    lc[o] = lc[last];
    rc[o] = rc[last];
    if(sign == 0)
        return ;
    if(l == r) {
        tre[o] =  v;
        return ;
    }
    int mid = (l + r) >> 1;
    if(p <= mid) update(lc[o], l, mid, lc[last], sign, p, v) ;
    else update(rc[o], mid + 1, r, rc[last], sign, p, v) ;
}
int query(int o, int l, int r, int p) {
    if(l == r) return tre[o] ;
    int mid = (l + r) >> 1;
    if(p <= mid) return query(lc[o], l, mid, p) ;
    else return query(rc[o], mid + 1, r, p) ;
}
int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    cin >> n >> m ;
    for(int i = 1; i <= n; i++) cin >> a[i] ;
    build(rt[0], 1, n) ;
    for(int i = 1; i <= m; i++) {
        int v, op, pos, val;
        cin >> v >> op >> pos ;
        if(op == 1) {
            cin >> val ;
            update(rt[i], 1, n, rt[v], 1, pos, val) ;
        } else {
            int ans = query(rt[v], 1, n, pos) ;
            update(rt[i], 1, n, rt[v], 0, pos, 0);
            cout << ans << '\n';
        }
    }
    return 0;
}

 


题意是询问区间中是否有超过区间长度一半的数,有的话这个数是哪个。
维护一下个数,然后根据sum看是否进入左右子树来找就行了,如果最后没有则输出0。
代码如下:

Code
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <cmath>
#include <map>
using namespace std;
typedef long long ll;
const int N = 5e5 + 5;
int n, m, T;
int a[N];
int rt[N] ;
int sum[N * 40], ls[N * 40], rs[N * 40];
void pushup(int o) {
    sum[o] = sum[ls[o]] + sum[rs[o]] ;
}
void build(int &o, int l, int r) {
    o = ++T;
    if(l == r) {
        return ;
    }
    int mid = (l + r) >> 1;
    build(ls[o], l, mid) ;
    build(rs[o], mid + 1, r) ;
}
void update(int &o, int l, int r, int last, int p) {
    o = ++T;
    ls[o] = ls[last] ;
    rs[o] = rs[last] ;
    if(l == r) {
        sum[o] = sum[last] + 1;
        return ;
    }
    int mid = (l + r) >> 1;
    if(p <= mid) update(ls[o], l, mid, ls[last], p) ;
    else update(rs[o], mid + 1, r, rs[last], p);
    pushup(o) ;
}
int query(int L, int R, int l, int r, int k) {
    if(l == r) return l;
    int lsize = sum[ls[R]] - sum[ls[L]], rsize = sum[rs[R]] - sum[rs[L]];
    int mid = (l + r) >> 1 ;
    if(lsize >= k) return query(ls[L], ls[R], l, mid, k) ;
    else if(rsize >= k) return query(rs[L], rs[R], mid + 1, r, k) ;
    else return 0;
}
int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    cin >> n >> m;
    for(int i = 1; i <= n; i++) cin >> a[i];
    build(rt[0], 1, n) ;
    for(int i = 1; i <= n; i++) update(rt[i], 1, n, rt[i - 1], a[i]) ;
    for(int i = 1; i <= m; i++) {
        int l, r;
        cin >> l >> r;
        int mid = ((r - l + 1) >> 1) + 1;
        int v = query(rt[l - 1], rt[r], 1, n, mid) ;
        cout << v << '\n';
    }
    return 0;
}

 


题目给出一棵树,然后每条边都有一定的边权。然后有多个询问,对于每个询问给出 \(u,v,k\),回答从 \(u\)\(v\)的路径中,权值不大于 \(k\)的边有多少条。

这个题的做法还是挺多的,可以树剖,也可以直接用树状数组,我就说说主席树的做法吧。
首先因为是树上的路径,那么我们肯定是要求LCA的。然后我们对整颗数dfs一遍,在dfs过程中插入主席树,以当前结点的父亲结点为历史版本,那么这样我们就清晰地知道从根到当前结点路径的权值信息了。
对于每次询问,找到LCA,计算一波就行了。这里直接利用主席树找权值不超过k的个数,对于结点\(u\)\(cnt[u]\),那么最终的答案就为\(cnt[u]+cnt[v]-2*cnt[LCA]\)
代码如下:

Code

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 5;
int n, m, T, D;
int b[N << 1];
int head[N];
struct Edge{
    int v, w, next;
}e[N << 1];
struct Q{
    int u, v, w;
}q[N];
struct edge{
    int u, v, w;
}E[N << 1];
int tot;
void adde(int u, int v, int w) {
    e[tot].v = v; e[tot].next = head[u]; e[tot].w = w; head[u] = tot++;
    e[tot].v = u; e[tot].next = head[v]; e[tot].w = w; head[v] = tot++;
}
int f[N][22], deep[N];
int ls[N * 20], rs[N * 20], rt[N], sum[N * 20];
void build(int &o, int l, int r) {
    o = ++T;
    if(l == r) return ;
    int mid = (l + r) >> 1;
    build(ls[o], l, mid) ;
    build(rs[o], mid + 1, r) ;
}
void update(int &o, int l, int r, int last, int v) {
    o = ++T;
    sum[o] = sum[last] + 1;
    ls[o] = ls[last]; rs[o] = rs[last] ;
    if(l == r) return ;
    int mid = (l + r) >> 1;
    if(v <= mid) update(ls[o], l, mid, ls[last], v) ;
    else update(rs[o], mid + 1, r, rs[last], v) ;
}
void dfs(int u, int fa) {
    for(int i = head[u]; i != -1; i = e[i].next) {
        int v = e[i].v;
        if(v == fa) continue ;
        deep[v] = deep[u] + 1;
        f[v][0] = u;
        for(int j = 1; j <= 20; j++) f[v][j] = f[f[v][j - 1]][j - 1] ;
        int w = lower_bound(b + 1, b + D + 1, e[i].w) - b;
        update(rt[v], 1, n + m, rt[u], w) ;
        dfs(v, u);
    }
}
int lca(int x, int y) {
    if(deep[x] < deep[y]) swap(x, y) ;
    for(int i = 20; i >= 0; i--)
        if(deep[f[x][i]] >= deep[y]) x = f[x][i] ;
    if(x == y) return x;
    for(int i = 20; i >= 0; i--)
        if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i] ;
    return f[x][0] ;
}
int query(int o, int l, int r, int last, int v) {
    if(l == r) return sum[o] - sum[last] ;
    int mid = (l + r) >> 1;
    int s = sum[ls[o]] - sum[ls[last]] ;
    if(v <= mid) return query(ls[o], l, mid, ls[last], v) ;
    else return s + query(rs[o], mid + 1, r, rs[last], v) ;
}
int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    memset(head, -1, sizeof(head)) ;
    cin >> n >> m;
    for(int i = 1; i < n; i++) {
        int u, v, w;
        cin >> u >> v >> w ;
        E[i] = edge{u, v, w} ;
        b[i] = w ;
    }
    D = n;
    for(int i = 1; i <= m; i++) {
        int u, v, w;
        cin >> u >> v >> w ;
        q[i] = Q{u, v, w} ;
        b[D++] = w;
    }
    sort(b + 1, b + D) ;
    D = unique(b + 1, b + D) - b - 1;
    for(int i = 1; i < n; i++) {
        adde(E[i].u, E[i].v, E[i].w) ;
    }
    build(rt[1], 1, n + m) ;
    dfs(1,0) ;
    for(int i = 1; i <= m; i++) {
        int u = q[i].u, v = q[i].v, w = q[i].w;
        w = lower_bound(b + 1, b + D + 1, w) - b;
        int LCA = lca(q[i].u, q[i].v) ;
        int s1 = query(rt[u], 1, n + m, rt[1], w) ;
        int s2 = query(rt[v], 1, n + m, rt[1], w) ;
        int s3 = query(rt[LCA], 1, n + m, rt[1], w) ;
        //cout << w << ' ' << s1 << ' ' << s2 << ' ' << s3 << '\n';
        cout << s1 + s2 - 2 * s3 << '\n';
    }
    return 0 ;
}

转载于:https://www.cnblogs.com/heyuhhh/p/10753168.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值