题目传送门 : http://acm.hdu.edu.cn/showproblem.php?pid=4822
题目大意
给定一棵树,和树上节点 A A 、、 C C ,若节点到节点 A A 的距离严格小于到、 C C 的距离,那么称被 A A 占有。有若干询问,每次给定一组、 B B 、,问各占领的节点数。
问题分析
若此题三个点减为两个点
A
A
、,那么题目应该不难,我们平分
A−>B
A
−
>
B
的路径,分下奇偶,一半归
A
A
,一半归,应当可以在
O((n+m)log2n)
O
(
(
n
+
m
)
l
o
g
2
n
)
解决。口糊讨论一下做法:首先倍增求
lca
l
c
a
,然后判断一下两个点离最近公共祖先的距离,同样用lca跳到
A−>B
A
−
>
B
路径的终点,最后分下奇偶,按子树的
size
s
i
z
e
算一算就完事了。
但加到三个点,好像一切都变复杂了啊……
咳,别颓了,有办法的。
如果说,我们仅仅只关注
A
A
点占有的点,那么这些点会是上边问题分为,和
A,C
A
,
C
所得点集的交集。这挺显然,但是看起来好像不容易实现啊。毕竟上边两个点的做法可不维护点集。如何维护点集?我们仔细观察发现,每个点集,要么是某颗子树,要么是整棵树减去某颗子树。所以我们自然就想到了
DFS
D
F
S
序。如果按照
DFS
D
F
S
序排序,那么一棵子树中的点在这个序列上连续。
而且此题甚至用不着生成这个序列!因为我们只需要知道这个子树所对应的区间(
[DFN[x],DFN[x]+size[x]−1]
[
D
F
N
[
x
]
,
D
F
N
[
x
]
+
s
i
z
e
[
x
]
−
1
]
),以及是整棵树删除这个子树还是选中这颗子树就可以了!剩下的就大力分类讨论一波就完事了。
听dalao说这题如果拓展到选若干个点,可以用非常dark优秀的虚树做。然而我并不会所以就不介绍了
参考程序
#include <cstdio>
#include <cstring>
#include <cmath>
#include <iostream>
#include <sstream>
#include <algorithm>
using namespace std;
const int MAXN = 100010, MAXLOG = 18;
int n, m, x, y, z;
int lp, f[MAXN], lin[MAXN << 1], nxt[MAXN << 1];
inline void add(int x, int y) { lin[++lp] = y; nxt[lp] = f[x]; f[x] = lp; return; }
int size[MAXN], dfn[MAXN], timeset;
int d[MAXN][MAXLOG];
int deep[MAXN];
void clean_up() {
lp = 0;
memset(deep, 0, sizeof(deep));
memset(f, 0, sizeof(f));
memset(size, 0, sizeof(size));
memset(dfn, 0, sizeof(dfn));
memset(d, 0, sizeof(d));
timeset = 0;
return;
}
void build_tree(int pos, int fa) {
deep[pos] = deep[fa] + 1;
dfn[pos] = ++timeset;
d[pos][0] = fa;
size[pos] = 1;
for(int i = 1; i < MAXLOG; i++) d[pos][i] = d[d[pos][i - 1]][i - 1];
for(int t = f[pos]; t; t = nxt[t]) {
if(lin[t] == fa) continue;
build_tree(lin[t], pos);
size[pos] += size[lin[t]];
}
return;
}
int get_lca(int x, int y) {//求两个点的lca
if(deep[x] < deep[y]) swap(x, y);
for(int i = MAXLOG - 1; i >= 0; i--)
if(deep[d[x][i]] >= deep[y]) x = d[x][i];
if(x == y) return x;
for(int i = MAXLOG - 1; i >= 0; i--)
if(d[x][i] != d[y][i]) {
x = d[x][i]; y = d[y][i];
}
return d[x][0];
}
int jump(int pos, int step) {//跳到pos的第step个父亲
int dep = deep[pos] - step;
for(int i = MAXLOG - 1; i >= 0; i--)
if(deep[d[pos][i]] >= dep) pos = d[pos][i];
return pos;
}
int solve(int x, int y, int z) {
int a = get_lca(x, y);
int b = get_lca(x, z);
int kind1, kind2, aa1, aa2, bb1, bb2;//kind若为1,表示选中这颗子树,2表示选中除这颗子树外的数
//aa1,aa2,bb1,bb2为区间左右端点
//下面这部分一定要注意细节啊啊啊啊啊啊啊
if(deep[x] - deep[a] >= deep[y] - deep[a]) {//求左右端点
kind1 = 1;
aa1 = jump(x, deep[x] - deep[a] - 1 - (deep[x] - deep[y]) / 2);//求要处理的子树的根节点
aa2 = dfn[aa1] + size[aa1] - 1;
aa1 = dfn[aa1];
} else {
kind1 = 2;
aa1 = jump(y, deep[y] - deep[a] - (deep[y] - deep[x] + 1) / 2);
aa2 = dfn[aa1] + size[aa1];
aa1 = dfn[aa1] - 1;
}
if(deep[x] - deep[b] >= deep[z] - deep[b]) {
kind2 = 1;
bb1 = jump(x, deep[x] - deep[b] - 1 - (deep[x] - deep[z]) / 2);
bb2 = dfn[bb1] + size[bb1] - 1;
bb1 = dfn[bb1];
} else {
kind2 = 2;
bb1 = jump(z, deep[z] - deep[b] - (deep[z] - deep[x] + 1) / 2);
bb2 = dfn[bb1] + size[bb1];
bb1 = dfn[bb1] - 1;
}
if(kind1 == 1 && kind2 == 1) {//一波大力分类讨论,一定要注意每一部分是6种而并非4个
if(aa2 < bb1) return 0;
if(bb2 < aa1) return 0;
if(aa1 <= bb1 && bb2 <= aa2) return bb2 - bb1 + 1;
if(aa1 <= bb1 && aa2 < bb2) return aa2 - bb1 + 1;
if(bb1 < aa1 && bb2 <= aa2) return bb2 - aa1 + 1;
if(bb1 < aa1 && aa2 < bb2) return aa2 - aa1 + 1;
}
if(kind1 == 1 && kind2 == 2) {
if(aa2 <= bb1) return aa2 - aa1 + 1;
if(aa1 >= bb2) return aa2 - aa1 + 1;
if(aa1 <= bb1 && bb2 <= aa2) return bb1 - aa1 + 1 + aa2 - bb2 + 1;
if(aa1 <= bb1 && aa2 < bb2) return bb1 - aa1 + 1;
if(bb1 < aa1 && bb2 <= aa2) return aa2 - bb2 + 1;
if(bb1 < aa1 && aa2 < bb2) return 0;
}
if(kind1 == 2 && kind2 == 1) {
if(bb2 <= aa1) return bb2 - bb1 + 1;
if(bb1 >= aa2) return bb2 - bb1 + 1;
if(aa1 < bb1 && bb2 < aa2) return 0;
if(aa1 < bb1 && aa2 <= bb2) return bb2 - aa2 + 1;
if(bb1 <= aa1 && bb2 < aa2) return aa1 - bb1 + 1;
if(bb1 <= aa1 && aa2 <= bb2) return aa1 - bb1 + 1 + bb2 - aa2 + 1;
}
if(kind1 == 2 && kind2 == 2) {
int t = 0;
if(aa1 <= bb1) t = aa1; else t = bb1;
if(aa2 >= bb2) t += n - aa2 + 1; else t += n - bb2 + 1;
if(aa1 >= bb2) t += aa1 - bb2 + 1;
if(bb1 >= aa2) t += bb1 - aa2 + 1;
return t;
}
}
void work() {
clean_up();
scanf("%d", &n);
for(int i = 1; i < n; i++) {
scanf("%d%d", &x, &y);
add(x, y); add(y, x);
}
build_tree(1, 1);//建树
scanf("%d", &m);
for(int i = 1; i <= m; i++) {
scanf("%d%d%d", &x, &y, &z);
printf("%d %d %d\n", solve(x, y, z), solve(y, x, z), solve(z, x, y));
}
return;
}
int main() {
int t;
scanf("%d", &t);
for(int i = 1; i <= t; i++) work();
return 0;
}