断断续续看了许久的主席树,简单记录一下。
什么样的题目用主席树呢,比如POJ 2104 求区间第K大的数是谁?
当时做这个题时,感觉分块+二分可以搞,就写了好久,改了好久,始终TLE,还是学学主席树把= =。
先吐槽一下线段树:
线段树竟然是被一个黄嘉泰的大佬因不会划分树来代替的,,,,,因缩写是HJT取名为主席树= =!orz
主席树大体思路:
我们怎样求区间第K大数呢:
假如我们能够利用前缀和思想,把每个数都以它为根建立一棵线段树,上面能够统计第几大的数出现的次数的话,我们就可以根据像二叉平衡树那样查找第K大的数。
假如我们建立一个线段树,区间划分和正常的线段树一样,只不过区间代表的是第几大数,不在是区间上的点。
先声明一下线段树的结点:
struct Node{
int l,r,sum;
}p[maxn*40];
// L代表的是“第几大的左端点”,R代表的是“第几大的右端点”,sum 代表是1~i个数在[L,R]这个第几大区间上出现的次数。
那么我们就可以n 次更新(修改)线段树 ,根据这个数在所有数中第几大来确定往左树走还是往右树上走,直到走到叶子结点(L == R)为止。
n 次更新线段树,因为每次都是往左或者往右 这是一个log级别的次数,因为我们只需要nlogn 个空间就可以完成一个有n 个版本的线段树。
假设root[i]表示以第i 个数为根的 根节点编号。
那么我们的update函数就可以这样写:
先建立一个新结点,和root[i-1]相等,这样就可以让新结点的左右子树和root[i-1]的一样了,相当于是引用,部分引用,部分修改。然后这样修改新结点的sum变量就可以了。 这样一直走一直走,根据第i 个数的第几大来确定往左还是往右走。
int update(int l,int r,int c,int k){
int nc = ++cnt;
p[nc] = p[c];
p[nc].sum++;
int mid = l+r>>1;
if (l == r) return nc;
if (mid >= k) p[nc].l = update(l,mid,p[c].l,k);
else p[nc].r = update(mid+1,r,p[c].r,k);
return nc;
}
query查询函数:
比如说我们要查询[x,y]这个区间上第k 大的数:
刚开始我们肯定从根结点(1~n)开始。
我们先求出[x,y]区间上,第1~mid大的数有几个(假设有sum 个),如果sum >= k,那么这个第k大数肯定在根节点的左子树,否则如果sum < k,第k 大数肯定在右子树上。那么问题就是如何求一个区间上的1~mid 大的呢?
根据我们建树的性质,我们让第y 个版本的线段树目前结点(根结点)的左儿子的sum 减去 第x-1个版本的线段树目前结点(根节点)的左儿子的sum。这个差就是[x,y]区间上1~mid大数的个数。这样我们找到最后找到一个叶子结点L==R,那么这个L 就是原数组离散化后的下标。
int query(int l,int r,int x,int y,int k){
if (l == r) return l;
int mid = l + r >> 1;
int sum = p[p[y].l ].sum - p[p[x].l ].sum;
if (sum >= k) return query(l,mid,p[x].l,p[y].l,k);
else return query(mid+1,r,p[x].r,p[y].r,k-sum);
}
好了,这样区间第K大数就解决了,其实理解了,感觉很巧妙的= =。
常用的离散化方法:
sort(v.begin(),v.end());
v.erase(unique(v.begin(),v.end()),v.end());
参考代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int maxn = 1e5 + 10;
int root[maxn];
vector<int>v;
struct Node{
int l,r,sum;
}p[maxn*40];
int cnt = 0;
int build(int l,int r){
int rt = ++cnt;
p[rt].sum = 0;
p[rt].l = p[rt].r = 0;
if (l == r) return rt;
int mid = l+r>>1;
p[rt].l = build(l,mid);
p[rt].r = build(mid+1,r);
return rt;
}
int a[maxn];
int getid(int x){
return lower_bound(v.begin(),v.end(),x) - v.begin() + 1;
}
int update(int l,int r,int c,int k){
int nc = ++cnt;
p[nc] = p[c];
p[nc].sum++;
int mid = l+r>>1;
if (l == r) return nc;
if (mid >= k) p[nc].l = update(l,mid,p[c].l,k);
else p[nc].r = update(mid+1,r,p[c].r,k);
return nc;
}
int query(int l,int r,int x,int y,int k){
if (l == r) return l;
int mid = l + r >> 1;
int sum = p[p[y].l ].sum - p[p[x].l ].sum;
if (sum >= k) return query(l,mid,p[x].l,p[y].l,k);
else return query(mid+1,r,p[x].r,p[y].r,k-sum);
}
int main(){
int n, q;
scanf("%d %d",&n, &q);
root[0] = build(1,n);
for (int i = 1; i <= n; ++i){
int x;
scanf("%d",&x);
a[i] = x;
v.push_back(x);
}
sort(v.begin(),v.end());
v.erase(unique(v.begin(),v.end()),v.end());
for (int i = 1; i <= n; ++i){
root[i] = update(1,n,root[i-1],getid(a[i]));
}
while(q--){
int x,y,k;
scanf("%d %d %d",&x, &y, &k);
int ans = query(1,n,root[x-1], root[y], k);
printf("%d\n",v[ans-1]);
}
return 0;
}