【学习笔记】浅谈主席树(可持久化线段树)

前言&前置知识

主席树上没有主席,就跟老婆饼里没有老婆,Bamboo_Day 里没有 Bamboo 是一样的。
之所以叫主席树,是因为发明这个数据结构的人名字缩写为 hjt 跟某位伟大的主席一样。

连续两句话都记住了行末加句号我真是太伟大了

主席树是在线段树的基础上发展出来的,所以你要先学会线段树

当然你还要学会动态开点线段树 没错这是个空链接,等我以后写了补上

温馨提示: Bamboo_Day 的主席树不知为何常数巨大,请注意输入输出优化

引入

还是先思考一个问题:

维护一段序列,要求能单点修改区间求和,而且能查询某一历史版本的信息

第一反应应该都是线段树罢,但是历史版本怎么办?开一堆线段树吗?

其实也不是不行,但是这样太浪费了

做过 Bamboo_Day 出的 seg tree 的都知道,线段树下面的节点其实是可以挂来挂去的

什么意思呢?就是说,我们发现在建立新的线段树的时候,其实和前面一个版本的线段树就没差多少,如图
在这里插入图片描述
红色的为因为 3->4 而发生变化的节点

显然如果建立一颗新的线段树,我们就浪费了多建立没有被更改的节点的时间和空间

所以我们只需要给一个节点多几个父亲就可以愉快的解决问题力 ^ - ^
就像这样
在这里插入图片描述
把这新树拉正就是
在这里插入图片描述
这样大概就只需要重建 log ⁡ n \log n logn 个新节点 可能吧,我也不知道

实现

先看题

节点信息

很简单,和动态开点是一样的

维护 左右儿子编号,加和 然后没了

struct Node{
	int l=0,r=0;
	int sum = 0;
}tr[M*30];

值得注意的是,这道题只需要开 30 倍空间就够了,但是据说一般主席树的题要开 200 倍?反正空间尽量开够就对了,注意 MLE 和 RE 的问题

小技巧

著名的 当领导的料 曾说过:

delta 等于什么? 等于 b 2 + 4 a c b^2+4ac b2+4ac ,这个时候发现什么,我们设直线也喜欢设 y = k x + b y=kx+b y=kx+b ,这个时候,就很容易出错,所以,我们把直线改成 y + k x + m y+kx+m y+kx+m ……

以上浪费了十分钟左右的时间(实录

在线段树里面有一大堆的 lr ,确实很容易搞混,而且诸如 tr.l 这样的表达看上去非常的丑陋

那怎么办?

采用 define

#define ls(fa) tr[fa].l
#define rs(fa) tr[fa].r
#define sum(fa) tr[fa].sum

这样看上去就好看多了

push_up

和普通线段树是一样的

void push_up(int pos){
	sum(pos) = sum(ls(pos))+sum(rs(pos)); 
}

build

会了 push_up自然就可以建树了,建树也是和普通的一样的

void build(int pos, int l, int r){
	if(l == r){
		sum(pos) = a[l];
		return;
	} 
	int mid = (l+r) >> 1;
	ls(pos) = ++tot;
	rs(pos) = ++tot;
	build(ls(pos),l,mid);
	build(rs(pos),mid+1,r);
	push_up(pos); 
}

push_down

不好意思单点修改没有 push_down

update

基本思想就是,先让新节点和旧节点指向一样的东西,然后在根据修改的在哪个儿子上再对相应的儿子挂新节点

void update(int last, int now,int l, int r,int x, int y){ // l,r 当前所在的区间范围,x,y 把位置x改成y
	if(l == r){
		sum(now) = y;
		return;
	}
	ls(now) = ls(last);
	rs(now) = rs(last);
	int mid = (l+r) >> 1;
	if(x <= mid){
		ls(now) = ++tot;
		update(ls(last),ls(now),l,mid,x,y);
	}else{
		rs(now) = ++tot;
		update(rs(last),rs(now),mid+1,r,x,y);
	}
	push_up(now);
}

ask

最后查询也就是和普通线段树是一样的了

int ask(int pos, int l, int r, int x){
   if(l == r) return sum(pos);
   int mid = (l+r) >> 1;
   if(x <= mid) return ask(ls(pos),l,mid,x);
   else return ask(rs(pos),mid+1,r,x);
}

Code

对于历史版本的保存开个数组存一下根节点就 OK 了

#include <bits/stdc++.h>
const int N = 1e6+10;
const int M = 1e6+10;
#define ls(fa) tr[fa].l
#define rs(fa) tr[fa].r
#define sum(fa) tr[fa].sum

using namespace std;
struct Node{
	int l=0,r=0;
	int sum = 0;
}tr[M*30];
int tot = 1;
int root[N],a[N],n,m;
void push_up(int pos){
	sum(pos) = sum(ls(pos))+sum(rs(pos)); 
}
void update(int last, int now,int l, int r,int x, int y){
	if(l == r){
		sum(now) = y;
		return;
	}
	ls(now) = ls(last);
	rs(now) = rs(last);
	int mid = (l+r) >> 1;
	if(x <= mid){
		ls(now) = ++tot;
		update(ls(last),ls(now),l,mid,x,y);
	}else{
		rs(now) = ++tot;
		update(rs(last),rs(now),mid+1,r,x,y);
	}
	push_up(now);
}
const int up = 1e6+10;
const int down = 1;
int ask(int pos, int l, int r, int x){
	if(l == r) return sum(pos);
	int mid = (l+r) >> 1;
	if(x <= mid) return ask(ls(pos),l,mid,x);
	else return ask(rs(pos),mid+1,r,x);
}
void build(int pos, int l, int r){
	if(l == r){
		sum(pos) = a[l];
		return;
	} 
	int mid = (l+r) >> 1;
	ls(pos) = ++tot;
	rs(pos) = ++tot;
	build(ls(pos),l,mid);
	build(rs(pos),mid+1,r);
	push_up(pos); 
}
int main(){
	cin >> n >> m;
	root[0] = ++tot;
	for(int i = 1; i <= n; i++){
//		cin >> a[i];
		scanf("%d",&a[i]);
	}
	build(root[0],1,n);
	for(int i = 1;i <= m; i++){
		int v,op,x,y;
//		cin >> v >> op;
		scanf("%d%d",&v,&op);
		if(op == 1){
//			cin >> x >> y;
			scanf("%d%d",&x,&y); 
			root[i] = ++tot;
			update(root[v],root[i],1,n,x,y);
		}else{
//			cin >> x;
			scanf("%d",&x);
			root[i] = root[v];
			printf("%d\n",ask(root[i],1,n,x));
		}
	}
	return 0;
}

应用

主席树其实本质上还是个权值线段树,较多用于解决 区间 kth 问题

怎么个着呢

忘记给题了

把整段序列看作一个时间轴,对每一个节点建立一颗线段树,然后你就拥有了查询 [ 1 , r ] [1,r] [1,r] 的 kth 的能力

不会 kth 的自己去学

那么 [ l , r ] [l,r] [l,r] 的 kth 怎么跑呢,其实也简单,把 l l l r r r 一起跑,把 s u m ( r ) − s u m ( l ) sum(r) - sum(l) sum(r)sum(l) 就是 [ l , r ] [l,r] [l,r] 的信息了

Code

#include <bits/stdc++.h>
const int N = 2e5+10;
const int M = 1e6+10;
#define ls(fa) tr[fa].l
#define rs(fa) tr[fa].r
#define sum(fa) tr[fa].sum

using namespace std;
struct Node{
	int l=0,r=0;
	int sum = 0;
}tr[M*30];
int tot = 1;
int root[N],a[N],n,m;
void push_up(int pos){
	sum(pos) = sum(ls(pos))+sum(rs(pos)); 
}
void update(int last, int now,int l, int r,int x, int y){
	if(l == r){
		sum(now) = sum(last)+y;
		return;
	}
	ls(now) = ls(last);
	rs(now) = rs(last);
	int mid = (l+r) >> 1;
	if(x <= mid){
		ls(now) = ++tot;
		update(ls(last),ls(now),l,mid,x,y);
	}else{
		rs(now) = ++tot;
		update(rs(last),rs(now),mid+1,r,x,y);
	}
	push_up(now);
}
const int up = 1e9+5;
const int down = -(1e9+5);
int kth(int last, int now, int l, int r, int k){
	if(l == r) return l;
	int mid = (l+r) >> 1;
	int val = sum(ls(now)) - sum(ls(last));
	if(val >= k) return kth(ls(last),ls(now),l,mid,k);
	else return kth(rs(last),rs(now),mid+1,r,k-val);
}
int main(){
	cin >> n >> m;
	for(int i = 1; i <= n; i++) cin >> a[i];
	for(int i = 1;i <= n; i++){
		root[i] = ++tot;
		update(root[i-1],root[i],down,up,a[i],1);
	}
	while(m--){
		int l,r,k;
		cin >> l >> r >> k;
		cout << kth(root[l-1],root[r],down,up,k) << endl;
	}
	return 0;
}

习题

P1383 高级打字机

#include <bits/stdc++.h>
const int N = 1e5+10;
#define ls(pos) tr[pos].l
#define rs(pos) tr[pos].r
#define mid ((l+r)>>1)
#define len(pos) tr[pos].len
const int up = 1e6;
const int down = 1;
using namespace std;
struct Node{
	int l=0,r=0;
	char ch;
	int len = 0;
}tr[N*40];
int tot = 0;
int rt[N];
void push_up(int pos){
	len(pos) = len(rs(pos)) + len(ls(pos));
}
void modify(int last, int now, int l, int r, int x, char y){
	if(l == r){
		tr[now].len = 1;
		tr[now].ch = y;
		return;
	}
	len(now) = len(last);
	ls(now) = ls(last);
	rs(now) = rs(last);
	if(x <= mid){
		ls(now) = ++tot;
		modify(ls(last),ls(now),l,mid,x,y);
	}else{
		rs(now) = ++tot;
		modify(rs(last),rs(now),mid+1,r,x,y);
	}
	push_up(now);
}
char ask(int pos, int l, int r, int x){
	if(l == r){
		return tr[pos].ch;
	}
	if(x <= mid){
		return ask(ls(pos),l,mid,x);
	}else{
		return ask(rs(pos),mid+1,r,x);
	}
}
int n;

int main(){
	cin >> n;
	int now = 1;
	while(n--){
		char op;
		cin >> op;
		if(op == 'T'){
			rt[now] = ++tot;
			char x;
			cin >> x;
			modify(rt[now-1],rt[now],down,up,tr[rt[now-1]].len+1,x);
			now++;
		}
		if(op == 'U'){
			int x;
			cin >> x;
			rt[now] = rt[now-x-1];
			now++;
		}
		if(op == 'Q'){
			int x;
			cin >> x;
			cout << ask(rt[now-1],down,up,x) << endl;
		}
	}	
	return 0;
} 

Destiny

#include <bits/stdc++.h>
const int N = 5e5+10;

using namespace std;
//struct Node{
//	int l,r,sum;
//}tr[N*40];
inline char gc(){
    static char now[1<<16],*S,*T;
    if (T==S){T=(S=now)+fread(now,1,1<<16,stdin);if (T==S) return EOF;}
    return *S++;
}
inline int read(){
    register int x=0,f=1;char ch=gc();
    while (!isdigit(ch)){if(ch=='-')f=-1;ch=gc();}
    while (isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=gc();}
    return (f==1)?x:-x;
}
int ll[N*40],rr[N*40],sumsum[N*40];

int rt[N];
int tot = 0;
void push_up(int pos){
	sumsum[pos] = sumsum[ll[pos]] + sumsum[rr[pos]];
}
void build(int last, int &now, int l, int r,int x){
	now = ++tot;
	ll[now] = ll[last];
	rr[now] = rr[last];
	sumsum[now] = sumsum[last];
	if(l == r){
		sumsum[now]++;
		return;
	}
	int mid = (l+r)>>1;
	if(x <= mid){
		build(ll[last],ll[now],l,mid,x);
	}else{
		build(rr[last],rr[now],mid+1,r,x);
	}
	push_up(now);
} 
int kth(int L, int R, int l, int r,int k){
	if(l == r){
		return l;
	}
	int mid = (l+r)>>1;
	if(sumsum[ll[R]] - sumsum[ll[L]] < k) return kth(rr[L],rr[R],mid+1,r,k-(sumsum[ll[R]] - sumsum[ll[L]]));
	else return kth(ll[L],ll[R],l,mid,k);
}
int query(int L, int R, int l, int r,int x){
	if(l == r){
		return sumsum[R] - sumsum[L];
	}
	int mid = (l+r)>>1;
	if(x <= mid) return query(ll[L],ll[R],l,mid,x);
	else return query(rr[L],rr[R],mid+1,r,x);
}
int a[N];
int n,m;

//void ins(int l,int r,int p,int &x,int pre){
//	sumsum[x=++tot]=sumsum[pre];
//	if(l==r)return sumsum[x]++,void();
//	int m=l+r>>1;
//	if(p<=m)ins(l,m,p,ll[x],ll[pre]),rr[x]=rr[pre];
//	else ins(m+1,r,p,rr[x],rr[pre]),ll[x]=ll[pre];
//	push_up(x);
//}
//int kth(int x,int y,int l,int r,int k){
//	if(l==r)return l;
//	int m=l+r>>1,sz=sumsum[ll[y]]-sumsum[ll[x]];
//	if(sz<k)return kth(rr[x],rr[y],m+1,r,k-sz);
//	return kth(ll[x],ll[y],l,m,k);
//}
//int query(int x,int y,int l,int r,int p){
//	if(l==r)return sumsum[y]-sumsum[x];
//	int m=l+r>>1;
//	if(p<=m)return query(ll[x],ll[y],l,m,p);
//	return query(rr[x],rr[y],m+1,r,p);
//}
int main(){
//	ios::sync_with_stdio(0);
//	cin.tie(0);
	n = read();
	m = read();
	for(int i = 1;i <= n; i++){
//		cin >> a[i];
//		scanf("%d",&a[i]);
		build(rt[i-1],rt[i],1,n,read());
//		ins(1,n,read(),rt[i],rt[i-1]);
	}
	while(m--){
		int l = read(),r = read(),k = read();
//		cin >> l >> r >> k;
		int rk = 1,ans = -1,nd = (r-l+1)/k+1;
		while(rk <= r-l+1){
			int q = kth(rt[l-1],rt[r],1,n,rk);
			if(query(rt[l-1],rt[r],1,n,q) >= nd){
				ans = q;
				break;
			}
			rk += nd;
		} 
//		cout << ans << '\n';	
		printf("%d\n",ans);			
	}
	return 0;
}

P7252 棒棒糖

#include <bits/stdc++.h>
const int N = 5e5+10;
const int up = 5e4+10;
using namespace std;
struct Node{
	int l,r,sum;
}tr[N*40];
#define ls(pos) tr[pos].l
#define rs(pos) tr[pos].r
#define sum(pos) tr[pos].sum
int rt[N];
int tot = 0;
void push_up(int pos){
	sum(pos) = sum(ls(pos)) + sum(rs(pos));
}
void build(int last, int &now, int l, int r,int x){
	now = ++tot;
	ls(now) = ls(last);
	rs(now) = rs(last);
	sum(now) = sum(last);
	if(l == r){
		sum(now)++;
		return;
	}
	int mid = (l+r)>>1;
	if(x <= mid){
		build(ls(last),ls(now),l,mid,x);
	}else{
		build(rs(last),rs(now),mid+1,r,x);
	}
	push_up(now);
} 
int kth(int L, int R, int l, int r,int k){
	if(l == r){
		return l;
	}
	int mid = (l+r)>>1;
	if(sum(ls(R)) - sum(ls(L)) < k) return kth(rs(L),rs(R),mid+1,r,k-(sum(ls(R)) - sum(ls(L))));
	else return kth(ls(L),ls(R),l,mid,k);
}
int query(int L, int R, int l, int r,int x){
	if(l == r){
		return sum(R) - sum(L);
	}
	int mid = (l+r)>>1;
	if(x <= mid) return query(ls(L),ls(R),l,mid,x);
	else return query(rs(L),rs(R),mid+1,r,x);
}
int a[N];
int n,m;

int main(){
	cin >> n >> m;
	for(int i = 1;i <= n; i++){
		cin >> a[i];
		build(rt[i-1],rt[i],1,up,a[i]);
	}
	while(m--){
		int l,r;
		cin >> l >> r;
		if((r-l+1)&1) {
			int tmp = kth(rt[l-1],rt[r],1,up,(r-l+2)/2);
			if(query(rt[l-1],rt[r],1,up,tmp)>(r-l+1)/2)
				printf("%d\n",tmp);
			else puts("0");
		} else {
			int tmp = kth(rt[l-1],rt[r],1,up,(r-l+1)/2),
				tnp = kth(rt[l-1],rt[r],1,up,(r-l+3)/2);
			if(tmp==tnp && query(rt[l-1],rt[r],1,up,tmp) > (r-l+1)/2) 
				printf("%d\n",tmp);
			else printf("0\n");
		}		
	}
	return 0;
}

完结撒花

有兴趣的可以去学一学 线段树分裂 和 线段树合并

  • 9
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Bamboo_Day

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值