题意:
给一颗树, 每条边的权值都是1, 问选取两个点, 使得树上所有点到最近的一个点的距离的最大值最小是多少。
思路:
先二分答案, 然后判断最大距离能不能小于k。
先选任意一个点把树转化成有根树, 然后选取一个深度最大的点u, 再求出这个点的k级祖先p, 然后再从p做bfs, 把到距离小于等于p的节点做标记, 然后再找一个深度最大并且没有被标记的节点, 然后在求出它的k级祖先, 然后再bfs, 最后看是不是所有点都被标记, 如果都被标记就表示k可行。
同步赛的时候贪心了下, 觉得好像行就敲了, 结果多输出了一个临时变量,逗比了好久。
后来想下好像可以证明。 假设这样贪心取点找到了深度最大的点是u,k级祖先是p, 如果存在一种更优的方案v,那么因为要覆盖u, 所以v一定是在p的子树中, 而p的子树中的所有点p都能覆盖, 也就是v能覆盖的所有点p都能覆盖, 所以选p更优。
貌似用树的直径可以做成O(N)。。。
还有就是以后能bfs搞的坚决不能用dfs。。。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
#define mxn 220010
#define eps 1e-10
#define mxe 440040
#define inf 1e20
#define LL long long
#pragma comment(linker, "/STACK:1024000000,1024000000")
int fst[mxn], nxt[mxe], to[mxe], e;
void init() {
memset(fst, -1, sizeof(fst));
e = 0;
}
void add(int u, int v) {
to[e] = v, nxt[e] = fst[u], fst[u] = e++;
}
int dep[mxn], fa[mxn];
int n;
bool vis[mxn][2];
int id, ans1, ans2;
int q[mxn], dis[mxn];
void bfs() {
id = 1;
dep[1] = 0;
int head = 0, tail = 1;
q[0] = 1;
fa[1] = -1;
while(head < tail) {
int u = q[head++];
for(int i = fst[u]; ~i; i = nxt[i]) {
int v = to[i];
if(v == fa[u]) continue;
dep[v] = dep[u] + 1;
if(dep[v] > dep[id]) id = v;
fa[v] = u;
q[tail++] = v;
}
}
}
int kfa(int u, int k) {
while(u != -1 && k) {
k--;
u = fa[u];
}
return u;
}
void markIt(int u, int k, int t) {
int head = 0, tail = 1;
q[0] = u;
dis[0] = k;
vis[u][t] = 1;
while(head < tail) {
int x = q[head];
int dd = dis[head++];
if(dd == 0) continue;
for(int i = fst[x]; ~i; i = nxt[i]) {
int v = to[i];
if(vis[v][t]) continue;
vis[v][t] = 1;
q[tail] = v, dis[tail++] = dd - 1;
}
}
}
bool check(int k) {
ans1 = ans2 = -1;
memset(vis, 0, sizeof(vis));
int p = kfa(id, k);
if(p == -1) {
ans1 = 1;
ans2 = 2;
return 1;
}
ans1 = p;
markIt(p, k, 0);
int u = -1;
for(int i = 1; i <= n; ++i) {
if(vis[i][0]) continue;
if(u == -1 || dep[i] > dep[u])
u = i;
}
if(u == -1) {
for(int i = 1; i <= n; ++i)
if(ans1 != i) {
ans2 = i;
break;
}
return 1;
}
p = kfa(u, k);
if(p == -1) {
if(ans1 == 1)
ans2 = 2;
else
ans2 = 1;
return 1;
}
ans2 = p;
markIt(p, k, 1);
for(int i = 1; i <= n; ++i)
if(!vis[i][0] && !vis[i][1])
return 0;
return 1;
}
int main() {
int cas;
scanf("%d", &cas);
while(cas--) {
scanf("%d", &n);
init();
for(int i = 1; i < n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
add(u, v), add(v, u);
}
bfs();
// printf("%d\n", id);
int l = 0, r = n;
while(l < r) {
int mid = (l + r) / 2;
if(check(mid))
r = mid;
else
l = mid + 1;
}
check(l);
printf("%d ", l);
printf("%d %d\n", ans1, ans2);
}
return 0;
}