最详细的讲解,让你一次学会主席树

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/creatorx/article/details/75446472

好久以前就想学习主席树这个黑科技,一直觉得很难,然后平时上课也没有什么好的时间,所以一直搁置到现在,最近遇到了一个比较简单,比较经典的问题,求区间第k小,比如poj2104,没有更新操作,只有查询操作,因为最近一直在学习分块思想,既然没有更新操作,我觉得可以用分块搞一下,但是一直tle,我的大致思想是把原序列分为为若干块,然后对每一块进行块内排序,每次 查询操作就是二分区间第k小元素设为x,judge的判断根据是区间小于x的数的个数小于k,因为我们在前面进行块内排序,对于一个区间里面如果包含整块,我们对这个块进行二分来找有多少数小于x,对于非整块,我们直接暴力就行,但是这种做法就是不过不了,一直tle,网上有人用这种方法过了,但是我不知道我的为什么过不了,可能写的比较丑吧,下面是我的分块代码,希望大神们指点一下

#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 10;
int n, m;
int a[maxn];//原数组
int MAGIC;//块的大小
int b[maxn];//按块排好序的数组
int c[maxn];//整体从小到大排好序的数组,用于整体二分求第k大元素
void init()
{
    MAGIC = (int)sqrt(n);
    for(int i = 0; i < n;)
    {
        if(i + MAGIC - 1 < n)
        {
            sort(b + i, b + i + MAGIC);
        }
        i += MAGIC;
    }
}
bool judge(int l, int r, int x, int k)//在区间l,r中小于x的数的个数是否小于k
{
    int num = 0;//小于x的数的个数
    for(int i = l; i <= r;)
    {
        if(i % MAGIC == 0 && (i + MAGIC - 1) <= r) //在块内,在数组b中进行二分查找
        {
           int left = i;
           int right = i + MAGIC - 1;
           int loc = -1;
           while(left <= right)
           {
               int mid = (left + right)>>1;
               if(b[mid] >= x)
               {
                   loc = mid;
                   right = mid - 1;
               }
               else left = mid + 1;
           }
           if(loc < 0) loc = MAGIC;
           else loc -= i;
           num += loc;
           i += MAGIC;
        }
        else
        {
           if(a[i] < x) num++;
           i++;
        }

    }
    if(num < k) return true;
    else return false;
}
int solve(int l, int r, int k)//求解区间l, r的第k大值
{
    int L = 0;
    int R = n - 1;
    int re = -1;
    while(L <= R)
    {
        int mid = (L + R)>>1;
        if(judge(l, r, c[mid], k))
        {
            re = c[mid];
            L = mid + 1;
        }
        else R = mid - 1;

    }
    return re;
}
int main()
{
    freopen("C:\\Users\\creator\\Desktop\\in.txt","r",stdin) ;
    freopen("C:\\Users\\creator\\Desktop\\out.txt","w",stdout) ;
    scanf("%d%d", &n, &m);
    for(int i = 0; i < n; i++)
    {
        scanf("%d", &a[i]);
        c[i] = a[i];
        b[i] = a[i];
    }
    sort(c, c + n);
    init();
    int l, r, k;
    for(int i = 1; i <= m; i++)
    {
        scanf("%d%d%d", &l, &r, &k);
        printf("%d\n", solve(l - 1, r - 1, k));
    }
    return 0;
}

居然分块过不了,没有办法只有用主席树来怼了,于是我花了一整天的时间来学了这个黑科技,刚开始看的时候看的我一脸懵逼,最后终于是学了一点皮毛,能够勉强解决静态区间第k小问题,至于动态区间第k小问题,我会在后面学习,到时候也会写一篇博客与大家交流讨论。

首先我们来了解什么叫做主席树,下面是从其他大佬的博客中复制过来的了解一下就行

所谓主席树呢,就是对原来的数列[1..n]的每一个前缀[1..i](1≤i≤n)建立一棵线段树,线段树的每一个节点存某个前缀[1..i]中属于区间[L..R]的数一共有多少个(比如根节点是[1..n],一共i个数,sum[root] = i;根节点的左儿子是[1..(L+R)/2],若不大于(L+R)/2的数有x个,那么sum[root.left] = x)。若要查找[i..j]中第k大数时,设某结点x,那么x.sum[j] - x.sum[i - 1]就是[i..j]中在结点x内的数字总数。而对每一个前缀都建一棵树,会MLE,观察到每个[1..i]和[1..i-1]只有一条路是不一样的,那么其他的结点只要用回前一棵树的结点即可,时空复杂度为O(nlogn)。

看了什么的介绍,感觉并没有什么卵用,还是一脸懵逼,下面我一点一点来讲解主席树的原理以及实现过程和代码。

在学习主席树之前,需要你很熟悉线段树这个东西,因为主席树的主体是多颗线段树,首先我们来简单的回顾一下线段树的简单特点和性质,我们熟悉的线段树一般是用一个结构体表示一个节点,每个节点有一个编号,节点里面一般有两个变量l, r来表示这个节点维护的区间,然后还有一个区间信息(比如区间最大值,最小值,和等,视具体问题而定),当然如果涉及到区间更新,还有一个add或者lazy叫做延迟标记的东西,然后一般线段树最明显的特点就行,一个父节点的编号是i,那么他的两只儿子节点的编号分别为2 * i(左) , 2 * i + 1(右),注意主席树在这一点有别于一般的线段树,每一个父节点,他的两个儿子节点的编号不一定满足这个关系。

我们先来分析一下对于任意一个区间,我们怎样求解这个区间的第k小值,当然一个最简单的做法就是把这个区间的数都拿出来排个序,然后直接输出就好,这很简单,但是复杂度爆表,我们考虑一个线段树的做法,假如一个区间l, r我们用一个用这个区间内出现过的数的个数组成一颗线段树,这是什么意思呢,求一个区间的第k小数,当然与这个区间有多少数比他小有关,在这里我举一个例子来说明一下怎样建立这样的一颗线段树。比如这个区间表示的序列是4,1,3,2,我们要求第2小,我们一眼就看出是2,那么我们怎样上面所说的线段树来求解呢,下面我画了一幅图来讲解,其中这颗线段树上的每个节点维护的是这个节点表示区间内的个数(假设每个数都不一样)

这里写图片描述
圈内的数字表示这个区间里面有多少个数,最后叶节点表示一个数字,对应上述序列中的一个数,注意,任意一个长度为N的序列我们都可以把他离散为一个1 ,2,3,,,,N的序列,只需要简单的hash一下就行。然后这样的一颗线段树建立出来了,我们怎样寻找区间第2小,因为叶节点从左到右表示的数依次增大,根据这个性质,以及每个节点保存了区间内的数的个数这个信息,我们可以轻易的找出区间第2小,具体的找法是,从根节点开始,看左儿子里面的数的 个数是不是大于等于2,如果是则第2小一定在左子树中,于是继续找左子树,反之找右子树,直到找到叶节点为止,然后直接返回叶节点表示的值就行。

但是多次询问区间第k小,我们每次这样建立一个线段树,这样不仅空间复杂度非常之高,而且时间复杂度也非常高,甚至比普通排序还要高,那么我们只不是可以想一个办法,使得对于每次我们查询不同的区间我们不需要重新建树,如果这样。时间复杂度和空间复杂度就大大降低了。

我们很容易就联想到了前缀和的概念,比如我们有一个问题。就是每次静态的求区间和,我们可以预处理所以的前缀和sum[i],我们每次求解区间l, r和时,我们可以直接得到答案为sum[r] - sum[l -1],这样就不需要对区间中的每个数进行相加来求解了。

同样一个道理,我们也可以利用前缀和这个思想来解决建树这个问题,我们只需要建立n颗“前缀”线段树就行,第i树维护[1,i]序列,这样我们处理任意区间l, r时就可以通过处理区间[1,l - 1], [1,r],就行,然后两者的处理结果进行相加相减就行。为什么满足相加减的性质,我们简单分析一下就很容易得出。如果在区间[1,l - 1]中有x个数小于一个数,在[1,r]中有y个数小于那个数,那么在区间[l,r]中就有y - x 个数小于那个数了,这样就很好理解为什么可以相加减了,另外,每颗树的结构都一样,都是一颗叶节点为n个的线段树。

上述利用前缀和的思想只是解决了时间复杂度的问题,并没有解决空间复杂度的问题,要解决空间复杂度问题。我们需要用到线段树的性质,我们每次更新一个数,那么与更新之前相比,这颗线段树改变只是一条链(从根节点到某一叶节点),那么我们可以充分利用这个特点,因为第i颗树与第i- 1颗树相比,只是更新了第i个元素,所以这两棵树有很多相同的节点,所以这两棵树可以共用很多节点(这也是为什么主席树的中节点编号不满足儿子节点编号是父节点编号的两倍和两倍加一的原因),于是这样就解决空间复杂度问题。

说了这么多,下面是我的代码,代码中有详细的注释,你可以结合代码和上面的讲解看(下面代码是poj2104的ac代码)。

#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
const int maxn = 1e5 + 10;
int n, m;
int cnt;
struct node{
    int L, R;//分别指向左右子树
    int sum;//该节点所管辖区间范围内数的个数
    node(){
        sum = 0;
    }
}Tree[maxn * 20];
struct value{
    int x;//值的大小
    int id;//离散之前在原数组中的位置
}Value[maxn];
bool cmp(value v1, value v2)
{
    return v1.x < v2.x;
}
int root[maxn];//多颗线段树的根节点
int rank[maxn];//原数组离散之后的数组
void init()
{
    cnt = 1;
    root[0] = 0;
    Tree[0].L = Tree[0].R = Tree[0].sum = 0;
}
void update(int num, int &rt, int l, int r)
{
    Tree[cnt++] = Tree[rt];
    rt = cnt - 1;
    Tree[rt].sum++;
    if(l == r) return;
    int mid = (l + r)>>1;
    if(num <= mid) update(num, Tree[rt].L, l, mid);
    else update(num, Tree[rt].R, mid + 1, r);
}
int query(int i, int j, int k, int l, int r)
{
    int d = Tree[Tree[j].L].sum - Tree[Tree[i].L].sum;
    if(l == r) return l;
    int mid = (l + r)>>1;
    if(k <= d) return query(Tree[i].L, Tree[j].L, k, l, mid);
    else return query(Tree[i].R, Tree[j].R, k - d, mid + 1, r);
}
int main()
{
    //freopen("C:\\Users\\creator\\Desktop\\in.txt","r",stdin) ;
    //freopen("C:\\Users\\creator\\Desktop\\out.txt","w",stdout) ;
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i++)
    {
        scanf("%d", &Value[i].x);
        Value[i].id = i;
    }
    //进行离散化
    sort(Value + 1, Value + n + 1, cmp);
    for(int i = 1; i <= n; i++)
    {
        rank[Value[i].id] = i;
    }
    init();
    for(int i = 1; i <= n; i++)
    {
        root[i] = root[i - 1];
        update(rank[i], root[i], 1, n);
    }
    int left, right, k;
    for(int i = 1; i <= m; i++)
    {
        scanf("%d%d%d", &left, &right, &k);
        printf("%d\n", Value[query(root[left - 1], root[right], k, 1, n)].x);
    }
    return 0;
}

展开阅读全文

没有更多推荐了,返回首页