题意:给出一棵树,树上有n个点,现从中选出m个点,并可以从任意一个点出发,求出到达所有点的最短时间。
思路:首先这些点中距离最远的两个点作为起始和结束点,即m点组成的树里面的直径。然后对于这条直径上的每个点u,用树形DP求出从u出发,计算经过子树中选中的的点并返回的最小时间,最后 ,加上走完直径的时间就是答案。
1、为什么是选择直径作为主线,因为选取其他点时可以证明经过的距离不是最短,总是能找到比他更短的,反证法可以证明。
2、树的直径的求法:从任一节点出发,BFS找到m点中最远的一个点,此为直径的一个端点。然后从该点出发,找到m点中最远的一个点,此即为另外一个直径端点。证明如下:
主要是利用了反证法:
假设 s-t这条路径为树的直径,或者称为树上的最长路
现有结论,从任意一点u出发搜到的最远的点一定是s、t中的一点,然后在从这个最远点开始搜,就可以搜到另一个最长路的端点,即用两遍广搜就可以找出树的最长路
证明:
1 设u为s-t路径上的一点,结论显然成立,否则设搜到的最远点为T则
dis(u,T) >dis(u,s) 且 dis(u,T)>dis(u,t) 则最长路不是s-t了,与假设矛盾
2 设u不为s-t路径上的点
首先明确,假如u走到了s-t路径上的一点,那么接下来的路径肯定都在s-t上了,而且终点为s或t,在1中已经证明过了
所以现在又有两种情况了:
1:u走到了s-t路径上的某点,假设为X,最后肯定走到某个端点,假设是t ,则路径总长度为dis(u,X)+dis(X,t)
2:u走到最远点的路径u-T与s-t无交点,则dis(u-T) >dis(u,X)+dis(X,t);显然,如果这个式子成立,
则dis(u,T)+dis(s,X)+dis(u,X)>dis(s,X)+dis(X,t)=dis(s,t)最长路不是s-t矛盾
3、求直径上每个点的最小值是需要用到father和树形dp。
#include <iostream>
#include <cstdio>
#include <string>
#include <cstring>
#include <fstream>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <map>
#include <set>
#include <iomanip>
using namespace std;
#define maxn 130000
#define mem(a) memset(a , 0 , sizeof(a))
vector<int> T[maxn];
int n , m;
int st , ed , len;
int point[maxn] , att[maxn]; //记录被攻击的节点 , 将被攻击节点存储
int dp[maxn] , fa[maxn]; //记录经过子树的最短时间 , 记录父亲节点
int dis[maxn];
int vis[maxn];
void bfs1(int pos)
{
mem(vis);
queue<int>q;
vis[pos] = 1;
q.push(pos);
dis[pos] = 0;
st = 0;
dis[st] = -1;
while(!q.empty())
{
int cur = q.front();
q.pop();
for(int i = 0 ; i < T[cur].size() ; i ++)
{
int u = T[cur][i];
if(vis[u]) continue;
vis[u] = 1;
dis[u] = dis[cur] + 1;
if(point[u])
{
if(dis[u] > dis[st]) st = u;
else st = min(st , u);
}
q.push(u);
}
}
}
void bfs2()
{
mem(vis);
queue<int>q;
vis[st] = 1;
q.push(st);
dis[st] = 0;
ed = 0;
dis[ed] = -1;
while(!q.empty())
{
int cur = q.front();
q.pop();
for(int i = 0 ; i < T[cur].size() ; i ++)
{
int u = T[cur][i];
if(vis[u]) continue;
vis[u] = 1;
dis[u] = dis[cur] + 1;
if(point[u])
{
if(dis[u] > dis[ed]) ed = u;
else ed = min(ed , u);
}
q.push(u);
}
}
len = dis[ed];
}
void dfs1(int pos , int f)
{
fa[pos] = f;
for(int i = 0 ; i < T[pos].size() ; i ++)
{
int u = T[pos][i];
if(u == f) continue;
dfs1(u , pos);
}
}
void dfs2(int pos , int fa , int last)
{
dp[pos] = 0;
for(int i = 0 ; i < T[pos].size() ; i ++)
{
int u = T[pos][i];
if(u == fa || u == last) continue;
dfs2(u , pos , last);
if(dp[u] >= 0) dp[pos] = dp[pos] + 2 + dp[u];
}
if(!dp[pos]) dp[pos] = -1;
if(point[pos]) dp[pos] = max(dp[pos] , 0);
}
int main()
{
while(scanf("%d %d" , &n , &m) != EOF)
{
int u , v;
for(int i = 1 ; i < n ; i ++)
{
scanf("%d %d" , &u , &v);
T[u].push_back(v);
T[v].push_back(u);
}
for(int i = 1 ; i <= m ; i ++) scanf("%d" , &u), point[u] = 1 , att[i] = u;
if(m == 1)
{
printf("%d\n%d\n" , att[1] , 0);
continue;
}
// cout << "test" << endl;
bfs1(att[1]);
bfs2(); //;两遍bfs确定树的直径
dfs1(st , 0); //初始化各节点父亲节点
memset(dp , -1 , sizeof(dp));
dfs2(ed , fa[ed] , 0);
int pos = ed ;
int l = 0;
// cout << "test" << endl;
while(st != pos)
{
l = pos ;
pos = fa[pos] ;
dfs2(pos , fa[pos] , l);
}
int ans = len;
pos = ed;
while(pos != 0)
{
ans += max(dp[pos] , 0);
pos = fa[pos];
}
printf("%d\n%d\n" , min(st , ed) , ans);
// for(int i = 1 ; i <= n ; i++) T[i].clear();
}
return 0;
}