码一个万能的求树的直径的方法,在树上带负权时候用用,有点麻烦,但也算是dfs两次吧。
我们试想一个问题,如果dfs两次,可能会导致如下现象:
这样的话,两次dfs就错了(从1开始做,直接奔着4去)。
于是我们考虑这样一个事情。如果对于树上的每个节点,我们都求出距离它最远的点以及他们间的距离,然后找出这n个距离中最大的一个,是不是就是直径呀。但是吧,说着都挺好的,但是这个距离哪那么好求啊!最早我想跑个最短路啊啥的,完全没有顾及到这个极特殊且性质贼多的图感受(树也是无向无环连通图)。hockey大神告诉我,这是一个线性的问题,我当时就一脸懵x了,完全是没有想到。
这是一个运用了treedp思想的方法。因为一个点要寻找最远的点,无非两种方法——子树与回溯(好押韵)。我们在子树中寻找是很简单的一个事情,但是回溯呢?我们进入第二次dfs。这时可以利用父节点的max值(按照dfs顺序,父节点的值已经固定了)去更新了。
But,如果父节点的max正是走了你当前这条路径得到了,你把他加上,不就相当于一条路走好几遍吗?这是不被允许的。但是转念一想,我们存个第二大的值(不能再走这条路了,要访问别的儿子),这样的话,我们就可以避开了。在做的时候,顺便存一个maxv,表示下一步经过了哪里。如果父节点不是单纯的直接又下去了,那就最大值去更新,否则次大值来更新。
贴上代码:bzoj1912为例。(全网我可能是唯一一个这样写的???)
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 100010
using namespace std;
struct adj {int to, next, flag;}e[2*N];
int n, k, head[N], cnt = 1, d[N], fa[N], max1[N], maxv1[N], max2[N], maxv2[N];
inline void ins(int x, int y) {e[++cnt].to = y; e[cnt].next = head[x]; e[cnt].flag = 1; head[x] = cnt;}
void dfs(int x, int f) {
fa[x] = f;
for(int i = head[x]; i; i = e[i].next) {
int y = e[i].to; if(f == y) continue;
d[y] = d[x] + 1; dfs(y, x);
}
}
void dfs1(int x, int f) {
for(int i = head[x]; i; i = e[i].next) {
if(e[i].to == f) continue;
dfs1(e[i].to, x);
if(max2[x] < max1[e[i].to] + e[i].flag) {
max2[x] = max1[e[i].to] + e[i].flag; maxv2[x] = e[i].to;
if(max2[x] > max1[x]) {
swap(max2[x], max1[x]);
swap(maxv2[x], maxv1[x]);
}
}
}
}
void dfs2(int x, int f) {
for(int i = head[x]; i; i = e[i].next) {
int y = e[i].to;
if(y == f) continue;
if(y == maxv1[x]) {
if(max2[y] < max2[x] + e[i].flag) {
max2[y] = max2[x] + e[i].flag; maxv2[y] = x;
if(max2[y] > max1[y]) {
swap(max2[y], max1[y]);
swap(maxv2[y], maxv1[y]);
}
}
}else {
if(max2[y] < max1[x] + e[i].flag) {
max2[y] = max1[x] + e[i].flag; maxv2[y] = x;
if(max2[y] > max1[y]) {
swap(max2[y], max1[y]);
swap(maxv2[y], maxv1[y]);
}
}
}
dfs2(y, x);
}
}
int main() {
scanf("%d%d", &n, &k);
memset(head, 0, sizeof(head));
for(int i = 1; i < n; ++i) {
int x, y; scanf("%d%d", &x, &y);
ins(x, y); ins(y, x);
}
d[1] = 0; dfs(1, 1);
int u = 1; for(int i = 2; i <= n; ++i) if(d[i] > d[u]) u = i;
d[u] = 0; dfs(u, u);
int v = 1; for(int i = 2; i <= n; ++i) if(d[i] > d[v]) v = i;
int ans = 2 * (n - 1) - d[v] + 1;
if(k == 1) {printf("%d", ans); return 0;}
int tmp = v;
while(tmp != u) {
for(int i = head[tmp]; i; i = e[i].next) {
if(e[i].to == fa[tmp]) {
e[i].flag = e[i^1].flag = -1;
break;
}
}
tmp = fa[tmp];
}
dfs1(u, u); dfs2(u, u);
int d2 = 0;
for(int i = 1; i <= n; ++i) d2 = max(d2, max1[i]);
ans = ans + 1 - d2;
printf("%d", ans);
return 0;
}