第一次敲树链剖分,以前都没看懂树链剖分用来干什么的,为多校留个纪念。
传送门:http://acm.hdu.edu.cn/showproblem.php?pid=6162
题意:
给一棵树,每次问某两点之间的最短路径上,能够满足价格在[a,b]这个范围内的价值之和是多少。其实题意很明显是个熟练剖分。但是以前都没有敲过树链剖分,所以并get不到树剖之后再用树维护。比赛的时候只意识到这个感觉像个树套树。。
顺带讲一下树链剖分的知识点。树链剖分就是把一棵树分成n条线段,然后就可以用线段树啊,splay,treap之类的东西来维护了。
树链剖分,主要分为重链和轻链,重链上的节点个数永远大于轻链上的个数。
第一次dfs。记录每一个节点的儿子个数。
第二次dfs,对于儿子数最多的那个儿子节点,作为重链,其他作为轻链,重新定义一个端点继续dfs下去。途中还能记录深度,这条链上的点所在的区间等。
附上树剖代码:
int size[MAXN]; //儿子节点数
int fater[MAXN]; //爸爸
int deep[MAXN]; //深度
int rak[MAXN]; //离散化树的节点
int id[MAXN]; //记录离散化后的rank对应回的节点
int son[MAXN]; //树链剖分的重链儿子
int top[MAXN]; //树链剖分的重链祖宗
int rk;
void init() {
rk = 0;
memset(son, -1, sizeof(son));
memset(fater, 0, sizeof(fater));
memset(size, 0, sizeof(size));
for (int i = 0; i < MAXN; i++) {
vec[i].clear();
}
}
void dfs1(int u, int fa, int dep) {
deep[u] = dep;
fater[u] = fa;
size[u] = 1;
int len = vec[u].size();
for (int i = 0; i < len; i++) {
int v = vec[u][i];
if (v != fa) {
dfs1(v, u, dep + 1);
size[u] += size[v];
if (son[u] == -1 || size[v] > size[son[u]]) {
son[u] = v;
}
}
}
}
void dfs2(int u, int tp) {
top[u] = tp;
rak[u] = ++rk;
id[rk] = u;
if (son[u] == -1) {
return ;
}
dfs2(son[u], tp);
int len = vec[u].size();
for (int i = 0; i < len; i++) {
int v = vec[u][i];
if (v != fater[u] && v != son[u]) {
dfs2(v, v); //轻边中继续制造重边
}
}
}
做完那么多对树的操作之后,我们就能用类似LCA的方法,去访问树上两个节点之间的和什么之类的。
如果两个点不在同一条链上,则让深度大的那个点,直接访问到top祖宗节点,那所需要添加的区间为该点到他祖宗节点的值之和。如果在同一条重链上,那就直接查询rank[x]到rank[y]的区间就好了
这里是用线段树维护的,附上树剖完后查询的代码:
ll treeQuery(int x, int y, int n) {
ll sum = 0;
int p1 = top[x], p2 = top[y]; //判断两个是否在同一重链上,看祖宗节点相不相同
while (p1 != p2) {
if (deep[p1] < deep[p2]) {
swap(p1, p2);
swap(x, y);
}
sum += query(rak[p1], rak[x], 1, n, 1);
x = fater[p1];
p1 = top[x];
}
if (deep[x] > deep[y]) {
swap(x, y);
}
sum += query(rak[x], rak[y], 1, n, 1);
return sum;
}
最后我们再来说这道题,官方题解上说,用树剖后用treap来维护,查询的时候操作treap。treap不单止有splay的机制,还有堆的机制,可以通过插入删除节点来查询到当时插入时的儿子节点区间,然后减一下区间和就好了,这是在线就可以完成的。
而有些题解上说的,树剖后,在线用线段树去维护最大值,最小值和区间和,直到找到满足最大值最小值在所给的[a,b]区间为止。然而,这样的时间复杂度是不对的,当成链状时,每次询问头和尾,且区间一直最小值大于线段树上最小值,最大值小于线段树上最大值,那么每次都需要跑完整个线段树(听说题目太水暴力LCA直接跑都可以过。。)
这里给出离线处理结果的方法,将所有询问的左区间排个序,每次插入线段树的时候,保证左区间的值小于询问的左区间的值,然后更新所在树剖链上的点,查询的时候查询价值所需区间,左边处理一次,右区间处理一次,然后要答案的时候直接减掉就好了。
代码有点长,没有标程那种treap优雅。
/*
@resources: hdu 6162
@date: 2017-08-24
@author: QuanQqqqq
@algorithm: 树链剖分 + segment tree
*/
#include <bits/stdc++.h>
#define MAXN 100005
#define ll long long
#define lson l, mid, root << 1
#define rson mid + 1, r, root << 1 | 1
using namespace std;
struct node {
ll l, r;
int id, x, y;
};
node val[MAXN], qsn[MAXN];
ll tree[MAXN << 2];
ll ansl[MAXN], ansr[MAXN];
int size[MAXN]; //儿子节点数
int fater[MAXN]; //爸爸
int deep[MAXN]; //深度
int rak[MAXN]; //离散化树的节点
int id[MAXN]; //记录离散化后的rank对应回的节点
int son[MAXN]; //树链剖分的重链儿子
int top[MAXN]; //树链剖分的重链祖宗
int rk;
vector<int> vec[MAXN];
void addEdge(int u, int v) {
vec[u].push_back(v);
vec[v].push_back(u);
}
void init() {
rk = 0;
memset(son, -1, sizeof(son));
memset(fater, 0, sizeof(fater));
memset(size, 0, sizeof(size));
for (int i = 0; i < MAXN; i++) {
vec[i].clear();
}
}
int cmpl(node a, node b) {
return a.l < b.l;
}
int cmpr(node a, node b) {
return a.r < b.r;
}
void dfs1(int u, int fa, int dep) {
deep[u] = dep;
fater[u] = fa;
size[u] = 1;
int len = vec[u].size();
for (int i = 0; i < len; i++) {
int v = vec[u][i];
if (v != fa) {
dfs1(v, u, dep + 1);
size[u] += size[v];
if (son[u] == -1 || size[v] > size[son[u]]) {
son[u] = v;
}
}
}
}
void dfs2(int u, int tp) {
top[u] = tp;
rak[u] = ++rk;
id[rk] = u;
if (son[u] == -1) {
return ;
}
dfs2(son[u], tp);
int len = vec[u].size();
for (int i = 0; i < len; i++) {
int v = vec[u][i];
if (v != fater[u] && v != son[u]) {
dfs2(v, v); //轻边中继续制造重边
}
}
}
void push_down(int root) {
tree[root] = tree[root << 1] + tree[root << 1 | 1];
}
void update(int need, int val, int l, int r, int root) {
if (l == r) {
tree[root] += val;
return ;
}
int mid = l + r >> 1;
if (need <= mid) {
update(need, val, lson);
} else {
update(need, val, rson);
}
push_down(root);
}
ll query(int L, int R, int l, int r, int root) {
if (L <= l && r <= R) {
return tree[root];
}
int mid = l + r >> 1;
ll sum = 0;
if (mid >= L) {
sum += query(L, R, lson);
}
if (mid < R) {
sum += query(L, R, rson);
}
return sum;
}
ll treeQuery(int x, int y, int n) {
ll sum = 0;
int p1 = top[x], p2 = top[y]; //判断两个是否在同一重链上的
while (p1 != p2) {
if (deep[p1] < deep[p2]) {
swap(p1, p2);
swap(x, y);
}
sum += query(rak[p1], rak[x], 1, n, 1);
x = fater[p1];
p1 = top[x];
}
if (deep[x] > deep[y]) {
swap(x, y);
}
sum += query(rak[x], rak[y], 1, n, 1);
return sum;
}
int main() {
int n, q, u, v;
while (~scanf("%d %d", &n, &q)) {
init();
for (int i = 1; i <= n; i++) {
scanf("%lld", &val[i].r);
val[i].id = i;
}
for (int i = 1; i <= n - 1; i++) {
scanf("%d %d", &u, &v);
addEdge(u, v);
}
dfs1(1, -1, 1);
dfs2(1, 1);
sort(val + 1, val + 1 + n, cmpr);
for (int i = 1; i <= q; i++) {
scanf("%d %d %lld %lld", &qsn[i].x, &qsn[i].y, &qsn[i].l, &qsn[i].r);
qsn[i].id = i;
}
memset(tree, 0, sizeof(tree));
sort(qsn + 1, qsn + q + 1, cmpl);
int j = 1;
for (int i = 1; i <= q; i++) {
while (val[j].r < qsn[i].l && j <= n) {
update(rak[val[j].id], val[j].r, 1, n, 1);
j++;
}
ansl[qsn[i].id] = treeQuery(qsn[i].x, qsn[i].y, n);
}
memset(tree, 0, sizeof(tree));
sort(qsn + 1, qsn + q + 1, cmpr);
j = 1;
for (int i = 1; i <= q; i++) {
while (val[j].r <= qsn[i].r && j <= n) {
update(rak[val[j].id], val[j].r, 1, n, 1);
j++;
}
ansr[qsn[i].id] = treeQuery(qsn[i].x, qsn[i].y, n);
}
for (int i = 1; i <= q; i++) {
if (i != 1) {
printf(" ");
}
printf("%lld", ansr[i] - ansl[i]);
}
puts("");
}
}