介绍
以下转自:大佬
还有知乎有一篇比较好的介绍: 浅谈权值线段树到主席树
前言
据说主席树这个名字的由来呢,是因为创始人的名字缩写hjt与某位相同,然后他因为不会划分树于是自创了这一个数据结构。好强啊orz
主席树能实现什么操作呢?最经典的就是查询区间第k小了,其他的还有诸如树上路径第k小啦,带修改第k小啦之类的。以静态区间第k小为例
定义
先贴一下某神犇对主席树的理解:所谓主席树呢,就是对原来的数列[1…n]的每一个前缀[1…i](1≤i≤n)建立一棵线段树,线段树的每一个节点存某个前缀[1…i]中属于区间[L…R]的数一共有多少个(比如根节点是[1…n],一共i个数,sum[root] = i;根节点的左儿子是[1…(L+R)/2],若不大于(L+R)/2的数有x个,那么sum[root.left] = x)。若要查找[i…j]中第k大数时,设某结点x,那么x.sum[j] - x.sum[i - 1]就是[i…j]中在结点x内的数字总数。而对每一个前缀都建一棵树,会MLE,观察到每个[1…i]和[1…i-1]只有一条路是不一样的,那么其他的结点只要用回前一棵树的结点即可,时空复杂度为O(nlogn)。
然而没有什么用,因为感觉根本没看懂
然后来说说我自己的理解吧。如何求出一个区间内第k小呢?直接sort当然可以,但是复杂度爆表。于是我们可以换一个思路,能否将[l,r]之间出现过的数都建成线段树呢?设节点为p,区间为[l,r],左儿子是[l,mid],右儿子是[mid+1,r]
要查找第k大的话,先看左儿子里有多少个数(表示小于等于mid的数的个数),如果大于k,进左子树找,否则令k−=左儿子数的个数,进右子树找
先来考虑一个序列:3,2,1,4
建完树之后是这样的
然后要查第2大,一下子就能发现是2了
(上面画的可能不是很严谨,大家将就下)
但我们不可能对每一个区间都建一棵树,那样的话空间复杂度绝对爆炸
然后可以转化一下思路:前缀和
区间[l,r]中小于等于mid的数的个数,可以转换为[1,r]中小于等于mid的数的个数减去[1,l−1]中小于等于mid的数的个数
于是我们只要对每一个前缀建一棵树即可
然后空间复杂度还是爆炸
然而我们又发现,区间[1,l−1]的树和区间[1,l]的树最多只会有logn个节点不同(因为每次新插入一个节点最多只会更新logn个节点),有许多空间是可以重复利用的
只要能将这些空间重复利用起来,就可以解决空间的问题了
还是上面那个序列:3,2,1,4
一开始先建一棵空树,然后一个个把每一个节点加进去
如果要看图的话可以点这里
这个时候有人就要问了,万一序列的数字特别大呢?
当然是离散化
将这些所有值离散一下就行了,可以保证所有数在1 n之间
然而感觉讲太多也没啥用……上代码好了,有详细的注释
例题
模板
Zoo
Description
JZ拥有一个很大的野生动物园。这个动物园坐落在一个狭长的山谷内,这个区域从南到北被划分成N个区域,每个区域都饲养着一头狮子。这些狮子从北到南编号为1,2,3,…,N。每头狮子都有一个觅食能力值Ai,Ai越小觅食能力越强。饲养员西西决定对狮子进行M次投喂,每次投喂都选择一个区间[I,J],从中选取觅食能力值第K强的狮子进行投喂。值得注意的是,西西不愿意对某些区域进行过多的投喂,他认为这样有悖公平。因此西西的投喂区间是互不包含的(即区间[1,10]不会与[3,4]或[5,10]同时存在,但可以与[9,11]或[10,20]一起)。同一区间也只会出现一次。你的任务就是算出每次投喂后,食物被哪头狮子吃掉了。
Input
第一行,两个整数N,M。
第二行,N个整数Ai。(1 ≤ Ai ≤ 2^31-1)$。
此后M行,每行描述一次投喂。第t+2行的三个数I,J,K表示在第t次投喂中,西西选择了区间[I,J]内觅食能力值第K强的狮子进行投喂。
Output
输出文件有M行,每行一个整数。第i行的整数表示在第i次投喂中吃到食物的狮子的觅食能力值。
Sample Input
7 2
1 5 2 6 3 7 4
1 5 3
2 7 1
Sample Output
3
2
Data Constraint
对于100%的数据,有1 ≤ N ≤ 10^5,1 ≤ M ≤ 5 × 10^4。
来源
JZOJ
Analysis 分析
可持久化线段树/主席树
主席树的主要思想就是:保存每次插入操作时的历史版本,以便查询区间第 k k k小。
怎么保存呢?暴力一点,每次开一棵线段树呗。
那空间还不爆掉?
那么我们分析一下,发现每次修改操作修改的点的个数是一样的。
(例如下图,修改了[1,8]中对应权值为1的结点,红色的点即为更改的点)
只更改了log(n)个结点,形成一条链,也就是说每次更改的结点数 = 树的高度。
注意主席树不能使用堆式存储法,就是说不能用x × 2,x × 2 + 1来表示左右儿子,而是应该动态开点,并保存每个节点的左右儿子编号。
所以我们只要在记录左右儿子的基础上存一下插入每个数的时候的根节点就可以持久化辣。
我们把问题简化一下:每次求[1,r]区间内的k小值。
怎么做呢?只需要找到插入r时的根节点版本,然后用普通权值线段树做就行了,如果不会用普通权值线段树做的话请参见开头部分的解释。
那么这个相信大家很简单都能理解,把问题扩展到原问题——求[l,r]区间k小值。
这里我们再联系另外一个知识理解:前缀和。
它运用了区间减法的性质,通过预处理从而达到O(1)回答每个询问。
那么我们主席树也行!如果需要得到[l,r]的统计信息,只需要用[1,r]的信息减去[1,l - 1]的信息就行了(请好好地想一想是不是如此)
那么至此,该问题解决!(完结撒花)
关于空间问题,我们分析一下:由于我们是动态开点的,所以一棵线段树只会出现2n-1个结点。然后,有n次修改,每次增加log(n)个结点。那么最坏情况结点数会达到2n-1+nlog(n),那么此题的n ≤ 10^5,通过计算得到 l o g 2 ( 1 0 5 ) log2^{(10^5)} log2(105)≈17,那么把n和log的结果代入这个式子,变成 2 × 1 0 5 − 1 + 17 × 1 0 5 2 × 10^5-1+17 × 10^5 2×105−1+17×105,忽略掉-1,大概就是19 × 10^5。
算了这么一大堆,I tell you: 千万不要吝啬空间!保守一点,直接上个2^5 × 10^5 = 32 × 10^5,接近原空间的两倍(即n << 5)。
(较真的同学请注意,如果你真的很吝啬,可以自己造个数据输出一下结点数量,但是如果数据没造好把自己卡掉了就尴尬了)
P.S: 实测该题需要开到20n+10个结点,19n+10会Wonderful Answer 80pts,该程序对于N = 10^5的数据开到了1968911个结点,大于19n+10。
代码:
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
const int maxn = 1e5;//数据范围
int tot,n,m;
int sum[(maxn << 5) + 10],rt[maxn + 10],ls[(maxn << 5) + 10],rs[(maxn << 5) + 10];
int a[maxn + 10],ind[maxn + 10],len;
inline int getid(const int &val)//离散化
{
return lower_bound(ind + 1,ind + len + 1,val) - ind;
}
int build(int l,int r)//建树
{
int root = ++tot;
if(l == r)
return root;
int mid = l + r >> 1;
ls[root] = build(l,mid);
rs[root] = build(mid + 1,r);
return root;//返回该子树的根节点
}
int update(int k,int l,int r,int root)//插入操作
{
int dir = ++tot;
ls[dir] = ls[root],rs[dir] = rs[root],sum[dir] = sum[root] + 1;
if(l == r)
return dir;
int mid = l + r >> 1;
if(k <= mid)
ls[dir] = update(k,l,mid,ls[dir]);
else
rs[dir] = update(k,mid + 1,r,rs[dir]);
return dir;
}
int query(int u,int v,int l,int r,int k)//查询操作
{
int mid = l + r >> 1,x = sum[ls[v]] - sum[ls[u]];//通过区间减法得到左儿子的信息
if(l == r)
return l;
if(k <= x)//说明在左儿子中
return query(ls[u],ls[v],l,mid,k);
else//说明在右儿子中
return query(rs[u],rs[v],mid + 1,r,k - x);
}
inline void init()
{
scanf("%d%d",&n,&m);
for(register int i = 1;i <= n;++i)
scanf("%d",a + i);
memcpy(ind,a,sizeof ind);
sort(ind + 1,ind + n + 1);
len = unique(ind + 1,ind + n + 1) - ind - 1;
rt[0] = build(1,len);
for(register int i = 1;i <= n;++i)
rt[i] = update(getid(a[i]),1,len,rt[i - 1]);
}
int l,r,k;
inline void work()
{
while(m--)
{
scanf("%d%d%d",&l,&r,&k);
printf("%d\n",ind[query(rt[l - 1],rt[r],1,len,k)]);//回答询问
}
}
int main()
{
init();
work();
return 0;
}
权值线段树例题
Weight Tarot 权值塔罗牌
Description 题目描述
最近小L收集了一套塔罗牌,每个塔罗牌上面有一个权值,不超过10^5,现有N个操作,分别有以下三种情况:
Add x 如果当前的手牌中没有权值为x的塔罗牌则加入,否则忽略该操作。
Remove x 如果当前的手牌中有权值为x的塔罗牌则弃掉该牌,否则忽略该操作。
Query 查询当前手牌中权值最接近的两张牌的权值之差,如果当前手牌数量少于2张牌,输出-1。
Input Format 输入格式
第一行,一个整数 N。
接下来 N 行,每行一个操作。
Output Format 输出格式
对于每个Query操作,输出一行,表示操作的结果。
Sample 样例
Sample Input 样例输入
12
Add 1
Remove 2
Add 1
Query
Add 7
Query
Add 10
Query
Add 5
Query
Remove 7
Query
Sample Output 样例输出
-1
6
3
2
4
Data Constraint 数据范围
N ≤ 10^5
Source 来源
改编自学校OJ一题
Analysis 分析
注意,对于各个操作的描述,可以看出塔罗牌的权值不能重复!
那么,就可以构造一棵根节点区间为[1,10^5]的线段树,在树上乱搞,这就是权值线段树的核心思想——(划重点)以数据范围为区间进行答案的维护(这句话是自己说的说错了别打我(逃)
大概是这样的:
对于每个结点,我们维护三个值:min,max,diff,分别代表区间最小值,区间最大值,区间最小差。
递归维护的时候,
min=min{lson.min,rson.min},max=max{lson.max,rson.max}
这个不难理解,但是最小差呢……
思考一下,容易发现更新的时候有三种状态: - 左儿子 [ l , m i d ] [l,mid] [l,mid]区间维护的最小差
右儿子[mid + 1,r]区间维护的最小差
右儿子区间维护的最小值与左儿子区间维护的最大值的差
然后对于三种情况,min一下就可以了(注意左儿子和右儿子都可能没有数,这时就需要去除对应的情况)
一个技巧:建树的时候可以把最小值初始化为正无穷,把最大值初始化为负无穷,这样更新最小差的时候只需要对第三种状态判断左右最值是否为无穷即可,因为对无穷作min是无用的。
代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
#define lson (root << 1)
#define rson (lson | 1)
using namespace std;
int n,x;
char opt[15];
struct node
{
int l,r;
int max,min,diff;
} seg[400010];
void build(int l,int r,int root = 1)//初始建树操作
{
seg[root].min = seg[root].diff = 0x3f3f3f3f;
seg[root].max = -0x3f3f3f3f;
seg[root].l = l,seg[root].r = r;
//以上初值不多说
if(l == r)
return ;
int mid = l + r >> 1;
build(l,mid,lson);
build(mid + 1,r,rson);
}
void generate(int x,int root = 1)//加牌操作
{
if(seg[root].l == seg[root].r)
{
seg[root].max = seg[root].min = x;
return ;
}
int mid = seg[root].l + seg[root].r >> 1;
if(x <= mid)
generate(x,lson);
if(x > mid)
generate(x,rson);
seg[root].max = max(seg[lson].max,seg[rson].max);
seg[root].min = min(seg[lson].min,seg[rson].min);
if(seg[lson].max == -0x3f3f3f3f || seg[rson].min == 0x3f3f3f3f)
seg[root].diff = min(seg[lson].diff,seg[rson].diff);
else
seg[root].diff = min(seg[rson].min - seg[lson].max,min(seg[lson].diff,seg[rson].diff));
}
void remove(int x,int root = 1)//弃牌操作
{
if(seg[root].l == seg[root].r)
{
seg[root].max = -0x3f3f3f3f;
seg[root].min = 0x3f3f3f3f;
return ;
}
int mid = seg[root].l + seg[root].r >> 1;
if(x <= mid)
remove(x,lson);
if(x > mid)
remove(x,rson);
seg[root].max = max(seg[lson].max,seg[rson].max);
seg[root].min = min(seg[lson].min,seg[rson].min);
if(seg[lson].max == -0x3f3f3f3f || seg[rson].min == 0x3f3f3f3f)
seg[root].diff = min(seg[lson].diff,seg[rson].diff);
else
seg[root].diff = min(seg[rson].min - seg[lson].max,min(seg[lson].diff,seg[rson].diff));
}
int main()
{
scanf("%d",&n);
build(1,100000);//数据范围
while(n--)
{
scanf("%s",opt);
if(!strcmp(opt,"Add"))
{
scanf("%d",&x);
generate(x);
}
else if(!strcmp(opt,"Remove"))//else if减少多余的strcmp调用
{
scanf("%d",&x);
remove(x);
}
else if(!strcmp(opt,"Query"))
printf("%d\n",seg[1].diff == 0x3f3f3f3f ? -1 : seg[1].diff);//注意判断-1情况
}
}
以区间第k小为例 洛谷p3834
代码:
//minamoto
#include<bits/stdc++.h>
#define N 200005
using namespace std;
inline int read(){
#define num ch-'0'
char ch;bool flag=0;int res;
while(!isdigit(ch=getchar()))
(ch=='-')&&(flag=true);
for(res=num;isdigit(ch=getchar());res=res*10+num);
(flag)&&(res=-res);
#undef num
return res;
}
int sum[N<<5],L[N<<5],R[N<<5];
int a[N],b[N],t[N];
int n,q,m,cnt=0;
int build(int l,int r){
int rt=++cnt;
//建树
sum[rt]=0;
if(l<r){
int mid=(l+r)>>1;
L[rt]=build(l,mid);
R[rt]=build(mid+1,r);
}
return rt;
}
int update(int last,int l,int r,int x){
int rt=++cnt;
L[rt]=L[last],R[rt]=R[last],sum[rt]=sum[last]+1;
//先继承上一次的信息
//L是左节点,R是右节点,sum是节点内数的个数
if(l<r){
int mid=(l+r)>>1;
if(x<=mid) L[rt]=update(L[last],l,mid,x);
else R[rt]=update(R[last],mid+1,r,x);
//如果有需要更新的信息,更新
//可以发现每一次更新的节点最多只有log n个
}
return rt;
}
int query(int u,int v,int l,int r,int k){
if(l>=r) return l;
int x=sum[L[v]]-sum[L[u]];
//查询操作
int mid=(l+r)>>1;
if(x>=k) return query(L[u],L[v],l,mid,k);
else return query(R[u],R[v],mid+1,r,k-x);
//如果左节点个数大于等于k,进左子树找第k小
//否则进右子树
}
int main(){
//freopen("testdata.in","r",stdin);
n=read(),q=read();
for(int i=1;i<=n;++i)
b[i]=a[i]=read();
sort(b+1,b+1+n);
m=unique(b+1,b+1+n)-b-1;
t[0]=build(1,m);
//先建一棵空树
for(int i=1;i<=n;++i){
int k=lower_bound(b+1,b+1+m,a[i])-b;
//离散
t[i]=update(t[i-1],1,m,k);
//然后每次在上一次的基础上建树
}
while(q--){
int x,y,z;
x=read(),y=read(),z=read();
int k=query(t[x-1],t[y],1,m,z);
printf("%d\n",b[k]);
}
return 0;
}
如果熟练了之后,可以发现其实第一步的建树过程是可以省略的,直接每一步加节点就行了
//minamoto
#include<bits/stdc++.h>
#define N 200005
using namespace std;
inline int read(){
#define num ch-'0'
char ch;bool flag=0;int res;
while(!isdigit(ch=getchar()))
(ch=='-')&&(flag=true);
for(res=num;isdigit(ch=getchar());res=res*10+num);
(flag)&&(res=-res);
#undef num
return res;
}
int sum[N<<5],L[N<<5],R[N<<5];
int a[N],b[N],t[N];
int n,q,m,cnt=0;
void update(int last,int &now,int l,int r,int x){