个人对于可持久化权值线段树一些理解
主要的解释大部分在代码块里面,代码是主席树的模板题,第k小数的部分代码;
修改函数
跟线段树不同的是,主席树递归的不再是左右边界,而是左右子节点,每次修改都是从根节点生成一条新的路径,与不牵涉修改的点相连
int insert(int p, int l, int r, int x)
{
int q = ++idx; //每次修改操作都会生成一个新的根节点
tr[q] = tr[p]; //将原来的节点复制下来
if (l == r)
{
tr[q].cnt++;
return q;
}
int mid = l + r >> 1;
if (x <= mid)
tr[q].l = insert(tr[p].l, l, mid, x);
else
tr[q].r = insert(tr[p].r, mid + 1, r, x);
//把从根节点到要修改的点的路径上所有的点都跟更新一遍,相当于生成了一条新的路径,但是这条路径仍然与不牵涉修改的点相连
tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt; //更新父节点
return q;
}
查询函数
主席树查询第k小数的主要思想就是不同历史版本之间的前缀和作差;
每次查询到一个结点之后,首先看新版本的树与旧版本的树左节点的差值,
根据差值决定下一步是向左递归还是向右递归
int query(int q, int p, int l, int r, int k) // p为新版本,q为老版本
{
if (l == r)
return r;
int cnt = tr[tr[q].l].cnt - tr[tr[p].l].cnt;
//先求出左边的点的数量,也就是权值
int mid = l + r >> 1;
if (k <= cnt)
return query(tr[q].l, tr[p].l, l, mid, k); //向左边递归,数量已经够了
else
return query(tr[q].r, tr[p].r, mid + 1, r, k - cnt); //向右边递归(因为查出的元素还不够,把左边减下去在继续从右边找);
}
完整代码
#include <bits/stdc++.h>
using namespace std;
const int N = 100010, M = 10010;
int n, m;
int a[N];
int root[N], idx;
vector<int> nums;
struct node
{
int l, r;
int cnt;
} tr[N * 20];
int find(int x)
{
return lower_bound(nums.begin(), nums.end(), x) - nums.begin();
}
int build(int l, int r)
{
int p = ++idx;
if (l == r)
return p;
int mid = l + r >> 1;
tr[p].l = build(l, mid), tr[p].r = build(mid + 1, r);
return p;
}
int insert(int p, int l, int r, int x)
{
int q = ++idx; //每次修改操作都会生成一个新的根节点
tr[q] = tr[p]; //将原来的节点复制下来
if (l == r)
{
tr[q].cnt++;
return q;
}
int mid = l + r >> 1;
if (x <= mid)
tr[q].l = insert(tr[p].l, l, mid, x);
else
tr[q].r = insert(tr[p].r, mid + 1, r, x);
//把从根节点到要修改的点的路径上所有的点都跟更新一遍,相当于生成了一条新的路径,但是这条路径仍然与不牵涉修改的点相连
tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt; //更新父节点
return q;
}
int query(int q, int p, int l, int r, int k) // p为新版本,q为老版本
{
if (l == r)
return r;
int cnt = tr[tr[q].l].cnt - tr[tr[p].l].cnt;
//先求出左边的点的数量,也就是权值
int mid = l + r >> 1;
if (k <= cnt)
return query(tr[q].l, tr[p].l, l, mid, k); //向左边递归,数量已经够了
else
return query(tr[q].r, tr[p].r, mid + 1, r, k - cnt); //向右边递归(因为查出的元素还不够,把左边减下去在继续从右边找);
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
{
scanf("%d", &a[i]);
nums.push_back(a[i]);
}
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
root[0] = build(0, nums.size() - 1);
for (int i = 1; i <= n; i++)
{
root[i] = insert(root[i - 1], 0, nums.size() - 1, find(a[i])); // 根据老版本去更新新版本的根节点。
}
while (m--)
{
int l, r, k;
scanf("%d%d%d", &l, &r, &k);
printf("%d\n", nums[query(root[r], root[l - 1], 0, nums.size() - 1, k)]);
}
return 0;
}