SPOJ - PT07J Query on a tree III 主席树
You are given a node-labeled rooted tree with n nodes.
Define the query (x, k): Find the node whose label is k-th largest in the subtree of the node x. Assume no two nodes have the same labels.
Input
The first line contains one integer n (1 <= n <= 105). The next line contains n integers li (0 <= li <= 109) which denotes the label of the i-th node.
Each line of the following n - 1 lines contains two integers u, v. They denote there is an edge between node u and node v. Node 1 is the root of the tree.
The next line contains one integer m (1 <= m <= 104) which denotes the number of the queries. Each line of the next m contains two integers x, k. (k <= the total node number in the subtree of x)
Output
For each query (x, k), output the index of the node whose label is the k-th largest in the subtree of the node x.
Example
Input:
5
1 3 5 2 7
1 2
2 3
1 4
3 5
4
2 3
4 1
3 2
3 2
Output:
5
4
5
5
题意:给一棵n个结点的点权树,以1为根,询问子树x权值k大的结点标号。
解:
对子树问题通常用dfs序解决。
将子树变为数组上一段连续的区间,然后问题转化为区间求第k大值。显然暴力在不好的情况下会超时,所以考虑用数据结构去维护。
区间第k大值我们通常用主席树,又叫函数式线段树来维护。
主席树程序过程:
(1)initial:离散化点权,设离散后点权序列为s1,s2…sm,则用sum数组统计区间[1,i]点权s1~si的值(类似桶排序里的桶,且桶的个数为max{s1,s2…sn}),并建立n个线段树维护这些区间。
(2)build:以只加入s1结点的桶建一棵完整的线段树。
(3)insert:每次加入一个结点si,并在上一棵线段树的基础上建立新的线段树。
具体过程:
(3.1)si大于tree(当前访问的线段树结点)左儿子权值,则在右儿子新添加一个结点,左儿子指向上一个线段树对应区间结点的左儿子
(3.2)小于同理
(4)query:用前缀和思想同时遍历左右端点(假设区间左闭右闭,左端点已经-1)对应的线段树,用前缀和思想:对每一个值:w[l,r]=w[1,r]-w[1,l-1]。当我们要找的kth大于用左儿子计算的这个值时,说明在这段值域的区间里,只有不到k个树,我们在右儿子遍历;否则在左儿子。
注意:存储函数式线段树要开大(比如20倍)的空间(因为树结点有n*logn个)否则会SIGSEGV。
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int MAXN = 111111;
#define fs first
#define se second
struct Node{
int l, r, w;
Node(int a=0, int b=0, int c=0):l(a), r(b), w(c){}
};
struct FunctionTree{
Node tree[4000000]; //注意要开40倍的空间
int top[MAXN]; //线段树顶编号
int w[MAXN]; //离散后权重
int sum[MAXN]; //sum[i]为w等于i的数的个数
int back[MAXN]; //i对应的离散前原数
int tot; //主席树结点总数
int build(int l, int r) {
int id = ++tot;
if (l == r) {
tree[id].w = sum[l];
tree[id].l = tree[id].r = 0;
return id;
}
int mid = l + r >> 1;
tree[id].l = build(l, mid);
tree[id].r = build(mid + 1, r);
tree[id].w = tree[tree[id].l].w + tree[tree[id].r].w;
return id;
}
void insert(int l, int r, int id1, int id2, int val) {
if (l == r) {
tree[id1].w = sum[l];
tree[id1].l = tree[id1].r = 0;
return;
}
int mid = l + r >> 1;
if (val > mid) {
tree[id1].l = tree[id2].l;
tree[id1].r = ++tot;
insert(mid + 1, r, tot, tree[id2].r, val);
}
if (val <= mid) {
tree[id1].l = ++tot;
tree[id1].r = tree[id2].r;
insert(l, mid, tot, tree[id2].l, val);
}
tree[id1].w = tree[tree[id1].l].w + tree[tree[id1].r].w;
}
void init(int *s, int n) {
pair<int, int> a[MAXN];
for (int i = 1; i <= n; ++i)
a[i] = make_pair(s[i], i);
tot = 0;
sort(a+1, a+1+n); //离散化
int tip = 0;
for (int i = 1; i <= n; ++i) {
if (i == 1 || a[i].fs != a[i-1].fs)
back[++tip] = a[i].fs;
w[a[i].se] = tip;
}
//构造主席树
top[1] = 1;
++sum[w[1]];
build(1, n);
for (int i = 2; i <= n; ++i) {
++sum[w[i]]; //sum[i]: w[x] = i 的个数,sum为大小1~tip的数组
top[i] = ++tot;
insert(1, n, top[i], top[i-1], w[i]);
}
}
int query(int l, int r, int id1, int id2, int k) {
if (l == r) return l;
int le1 = tree[id1].l;
int ri1 = tree[id1].r;
int le2 = tree[id2].l;
int ri2 = tree[id2].r;
int mid = l + r >> 1;
if (tree[le2].w - tree[le1].w < k)
return query(mid+1, r, ri1, ri2, k - (tree[le2].w - tree[le1].w));
else
return query(l, mid, le1, le2, k);
}
}FT;
int in[MAXN], out[MAXN], val[MAXN], w[MAXN];
vector<int> map[MAXN];
void dfs(int x, int fa) {
static int tot = 0;
in[x] = ++tot;
val[tot] = w[x];
for (int i = 0; i < map[x].size(); ++i)
if (map[x][i] != fa)
dfs(map[x][i], x);
out[x] = tot;
}
int id[MAXN];
int main()
{
freopen("in.txt", "r", stdin);
int n, q, a, b, c;
scanf("%d", &n);
for (int i = 1; i <= n; ++i) scanf("%d", &w[i]);
pair<int, int> pir[MAXN];
for (int i = 1; i <= n; ++i)
pir[i] = make_pair(w[i], i);
sort(pir+1, pir+1+n);
for (int i = 1; i <= n; ++i)
id[i] = pir[i].se;
for (int i = 1; i < n; ++i) {
scanf("%d%d", &a, &b);
map[a].push_back(b);
map[b].push_back(a);
}
dfs(1, 0);
FT.init(val, n);
scanf("%d", &q);
for (int i = 1; i <= q; ++i) {
scanf("%d%d", &a, &b);
printf("%d\n", id[FT.query(1, n, FT.top[in[a]-1], FT.top[out[a]], b)]);
}
return 0;
}
主席树模板:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
const int MAXN = 200000;
#define fs first
#define se second
struct Node{
int l, r, w;
Node(int a=0, int b=0, int c=0):l(a), r(b), w(c){}
};
struct FunctionTree{
Node tree[MAXN*40];
int top[MAXN]; //线段树顶编号
int w[MAXN]; //离散后权重
int sum[MAXN]; //sum[i]为w等于i的数的个数
int back[MAXN]; //i对应的离散前原数
int tot; //主席树结点总数
int build(int l, int r) {
int id = ++tot;
if (l == r) {
tree[id].w = sum[l];
tree[id].l = tree[id].r = 0;
return id;
}
int mid = l + r >> 1;
tree[id].l = build(l, mid);
tree[id].r = build(mid + 1, r);
tree[id].w = tree[tree[id].l].w + tree[tree[id].r].w;
return id;
}
void insert(int l, int r, int id1, int id2, int val) {
if (l == r) {
tree[id1].w = sum[l];
tree[id1].l = tree[id1].r = 0;
return;
}
int mid = l + r >> 1;
if (val > mid) {
tree[id1].l = tree[id2].l;
tree[id1].r = ++tot;
insert(mid + 1, r, tot, tree[id2].r, val);
}
if (val <= mid) {
tree[id1].l = ++tot;
tree[id1].r = tree[id2].r;
insert(l, mid, tot, tree[id2].l, val);
}
tree[id1].w = tree[tree[id1].l].w + tree[tree[id1].r].w;
}
void init(int *s, int n) {
pair<int, int> a[MAXN];
for (int i = 1; i <= n; ++i)
a[i] = make_pair(s[i], i);
tot = 0;
sort(a+1, a+1+n); //离散化
int tip = 0;
for (int i = 1; i <= n; ++i) {
if (i == 1 || a[i].fs != a[i-1].fs)
back[++tip] = a[i].fs;
w[a[i].se] = tip;
}
//构造主席树
top[1] = 1;
++sum[w[1]];
build(1, n);
for (int i = 2; i <= n; ++i) {
++sum[w[i]]; //sum[i]: w[x] = i 的个数,sum为大小1~tip的数组
top[i] = ++tot;
insert(1, n, top[i], top[i-1], w[i]);
}
}
int query(int l, int r, int id1, int id2, int k) {
if (l == r) return back[l];
int le1 = tree[id1].l;
int ri1 = tree[id1].r;
int le2 = tree[id2].l;
int ri2 = tree[id2].r;
int mid = l + r >> 1;
if (tree[le2].w - tree[le1].w < k)
return query(mid+1, r, ri1, ri2, k - (tree[le2].w - tree[le1].w));
else
return query(l, mid, le1, le2, k);
}
}FT;
int a[MAXN];
int main()
{
freopen("in.txt", "r", stdin);
int n, q, x, y, z;
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
scanf("%d", &a[i]);
FT.init(a, n);
scanf("%d", &q);
for (int i = 1; i <= q; ++i) {
scanf("%d%d%d", &x, &y, &z);
printf("%d\n", FT.query(1, n, FT.top[x-1], FT.top[y], z));
}
return 0;
}