树的直径
树的直径是指树上的最长简单路。
任选一点 w
为起点,对树进行搜索,找出离 w
最远的点 u
。
以 u
为起点,再进行搜索,找出离 u
最远的点 v
。则 u
到 v
的路径长度即为树的直径。
简单证明:
如果
w
在直径上,那么u
一定是直径的一个端点。反证:若u
不是端点,设直径的两端为S
与T
,则dist(w, u) > dist(u, T)
且dist(w, u) > dist(u, S)
,则最长路不是S - T
了,与假设矛盾。如果
w
不在直径上,且w
到其距最远点u
的路径与直径一定有一交点c
,那么由上一个证明可知,u
是直径的一个端点。如果
w
到最远点u
的路径与直径没有交点,设直径的两端为S
与T
,那么dist(w, u) > dist(w, c) + dist(c, T)
,推出dist(S, c) + dist(w, u) + dist(w, c) > dist(S, c) + dist(c, T) = dist(S, T)
,最长路不是S - T
与假设矛盾。因此
w
到最远点u
的路径与直径必有交点。S-----------c-----------T | w------u
树的重心
何谓重心
树的重心:找到一个点,以它为整棵树的根,其所有的子树中最大的子树节点数最少,那么这个点就是这棵树的重心,删去重心后,生成的多棵树尽可能平衡。
树的重心有下面几条常见性质:
定义1:找到一个点,其所有的子树中最大的子树节点数最少,那么这个点就是这棵树的重心。
定义2:以这个点为根,那么所有的子树(不算整个树自身)的大小都不超过整个树大小的一半。性质1:树中所有点到某个点的距离和中,到重心的距离和是最小的;如果有两个重心,那么他们的距离和一样。
性质2:把两个树通过一条边相连得到一个新的树,那么新的树的重心在连接原来两个树的重心的路径上。
性质3:把一个树添加或删除一个叶子,那么它的重心最多只移动一条边的距离。
树的重心可以通过简单的两次搜索求出,第一遍搜索求出以 i
为根的每个结点的子结点数量 son[i]
,第二遍搜索找出以 u 为整棵树的根使 max(son[u], n - son[u])
最小的结点 u
。
实际上这两步操作可以在一次遍历中解决。对结点 u
的每一个儿子 v
,递归的处理 v
,求出 son[v]
,然后判断是否是结点数最多的子树,处理完所有子结点后,判断 u
是否为重心。
void dfs(int u, int fa) {
int res = 0; // res 要定义在 dfs 内
son[u] = 1;
for(int i = head[u]; i; i = e[i].next) {
int v = e[i].to;
if(v == fa) continue;
dfs(v, u);
son[u] += son[v];
res = max(res, son[v]);
}
res = max(res, n - son[u]);
if(res < size) {
ans = u; size = res;
}
}
树的点分治
我们可以发现:
对于一个点,显然只有经过它的路径和不经过它的路径。
我们不考虑不经过它的路径。
经过它的路径的两个端点一定它在的两个子树里。
以它为根统计到它子树节点的距离。这样的单独的距离或两段距离之和一定经过这个根节点。
对于每个点我们都可以这样计算。
可是当树退化成一条链的时候。复杂度就会很高,不断的递归找子树。
所以我们需要按树的重心(因为重心删掉后最大子树节点数最小)来找这些点,这样可以把复杂度控制在
O(nlogn)
。
–
所以点分治的思想就是:
- 找重心把它作为根
- 解决根的路径问题
- 递归子树解决子问题(重复一二步骤)
针对每题不同的只有 cal()
–
size[]
是以删掉 u
后,最大连通块的大小。
size[0]
是整棵子树大小,一开始 rt = 0
,用之后更小 size
值更新 rt
。
sum
是当前子树内的总点数。
temnum
是子树编号。
cnt2
和 temnum
在 cal()
里要记得每次都要清一次空。因为 cnt2
每次是针对不同的重心的,而 temnum
每次针对是不同重心的子树的。
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 5;
struct Edge {
int next, to, w;
}e[N << 1];
struct seg {
int dis, pos;
}seg1[N << 1];
int n, q, sum;
bool ok[10000005];
int head[N], cnt = 0;
void add(int x, int y, int z) {
e[++ cnt].to = y;
e[cnt].next = head[x];
e[cnt].w = z;
head[x] = cnt;
}
int cnt2 = 0;
void add2(int dis, int pos) {
seg1[++ cnt2].dis = dis;
seg1[cnt2].pos = pos;
}
int rt;
int son[N], size[N], vis[N];
void dfs1(int u, int fa) {
son[u] = 1;
size[u] = 0;
for (int i = head[u]; i; i = e[i].next) {
int v = e[i].to;
if (v == fa || vis[v]) continue;
son[u] += son[v];
size[u] = max(size[u], son[v]);
}
size[u] = max(size[u], sum - son[u]);
if (size[u] < size[rt]) rt = u;
}
int dis[N];
void dfs2(int u, int fa, int num) {
son[u] = 1;
add2(dis[u], num);
ok[dis[u]] = 1;
for (int i = head[u]; i; i = e[i].next) {
int v = e[i].to;
if (v == fa || vis[v]) continue;
dis[v] = dis[u] + e[i].w;
dfs2(v, u, num);
son[u] += son[v];
}
}
void cal(int u) {
int temnum = 0; cnt2 = 0;
for (int i = head[u]; i; i = e[i].next) {
int v = e[i].to;
if (vis[v]) continue;
dis[v] = e[i].w;
dfs2(v, u, ++ temnum);
}
for (int i = 1; i < cnt2; i ++)
for (int j = i + 1; j <= cnt2; j ++)
if (seg1[i].pos != seg1[j].pos)
ok[seg1[i].dis + seg1[j].dis] = 1;
}
void solve(int u) {
vis[u] = 1;
cal(u);
for (int i = head[u]; i; i = e[i].next) {
int v = e[i].to;
if (vis[v]) continue;
size[0] = sum = son[v];
dfs1(v, rt = 0);
solve(rt);
}
}
int main() {
memset(ok, 0, sizeof(ok));
memset(vis, 0, sizeof(vis));
scanf("%d%d", &n, &q);
for (int i = 1; i < n; i ++) {
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
size[0] = sum = n;
dfs1(1, 0);
solve(rt);
for (int i = 1; i <= q; i ++) {
int k;
scanf("%d", &k);
if (ok[k]) printf("AYE\n");
else printf("NAY\n");
}
return 0;
}