题目大意:
给你一棵带权树,和一些选定的点。一个人从$i$点出发,要开车走遍所有选定的点(不必回到起点),要你分别输出$i=1\sim n$时,这个人走的最短的方案的长度。
解题思路:
首先把虚树构建出来(找出所有在这棵虚树中的节点即可),DFS一遍即可。
然后我们先假设它要回到起点,那么对于每个在虚树上的节点,它要走过的距离就是虚树上所有边的权值和的两倍,这个是显然的。
但是它不用回到起点,那我们只要减去起点到虚树上离它最远的点的距离即可(也就是少走最长的那条路,建立虚树的时候找出即可)。
而树上一个点离它最远的节点一定是树的直径的两个端点之一,因此我们求出虚树直径的两个端点,然后每次找一个离起点最远的端点,减去它到起点的距离即为答案(两遍DFS求虚树直径的顶点)。
然后对于不是虚树上的节点,只要找到离它最近的虚树节点,求这个节点的答案加上该节点到那个虚树节点的距离即可,一遍DFS即可找到所有非虚树节点的最近的虚树节点。
最后求答案即可,算两点间的距离时计算一下LCA即可
时间复杂度$O(n\log_2 n)$。
C++ Code:
#include<bits/stdc++.h>
#define N 500005
#define ll long long
#define Dis(a,b) (dis[a]+dis[b]-(dis[lca(a,b)]<<1))
int n,k,cnt=0,head[N],rt,nxtnd[N],L,R,fa[N][21],dep[N];
ll sum=0,dis[N],d[N];
struct edge{
int to,dis,nxt;
}e[N<<1];
inline ll max(ll a,ll b){return a<b?b:a;}
inline int readint(){
char c=getchar();
for(;!isdigit(c);c=getchar());
int d=0;
for(;isdigit(c);c=getchar())
d=(d<<3)+(d<<1)+(c^'0');
return d;
}
void dfs(int now,int pr,int pre){
if(nxtnd[now]==-1)pr=now;
for(int i=head[now];i;i=e[i].nxt)
if(e[i].to!=pre){
dis[e[i].to]=dis[now]+e[i].dis;
fa[e[i].to][0]=now;
dep[e[i].to]=dep[now]+1;
dfs(e[i].to,pr,now);
if(nxtnd[e[i].to]==-1)
nxtnd[now]=-1,sum+=e[i].dis;
}
if(!nxtnd[now])nxtnd[now]=pr;
}
void dfs2(int now,int pre){
for(int i=head[now];i;i=e[i].nxt)
if(e[i].to!=pre&&nxtnd[e[i].to]==-1){
d[e[i].to]=d[now]+e[i].dis;
dfs2(e[i].to,now);
}
}
int lca(int x,int y){
if(dep[x]<dep[y])x^=y^=x^=y;
for(int i=20;i>=0;--i)
if(dep[fa[x][i]]>=dep[y])x=fa[x][i];
if(x==y)return x;
for(int i=20;i>=0;--i)
if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
void dfs0(int now,int pre){
for(int i=head[now];i;i=e[i].nxt)
if(e[i].to!=pre){
if(nxtnd[e[i].to]!=-1){
if(nxtnd[now]==-1)nxtnd[e[i].to]=now;else
nxtnd[e[i].to]=nxtnd[now];
}
dfs0(e[i].to,now);
}
}
int main(){
n=readint(),k=readint();
memset(head,0,sizeof head);
memset(dep,0,sizeof dep);
for(int i=1;i<n;++i){
int x=readint(),y=readint(),z=readint();
e[++cnt]=(edge){y,z,head[x]};
head[x]=cnt;
e[++cnt]=(edge){x,z,head[y]};
head[y]=cnt;
}
memset(nxtnd,0,sizeof nxtnd);
memset(dis,0,sizeof dis);
nxtnd[rt=readint()]=-1;
dep[rt]=1;
for(int i=1;i<k;++i)nxtnd[readint()]=-1;
for(int i=head[rt];i;i=e[i].nxt){
dep[e[i].to]=2;
fa[e[i].to][0]=rt;
dis[e[i].to]=e[i].dis;
dfs(e[i].to,rt,rt);
if(nxtnd[e[i].to]==-1)sum+=e[i].dis;
}
dfs0(rt,0);
memset(d,0,sizeof d);
dfs2(rt,0);
L=R=rt;
for(int i=1;i<=n;++i)
if(nxtnd[i]==-1&&d[i]>d[L])L=i;
memset(d,0,sizeof d);
dfs2(L,0);
for(int i=1;i<=n;++i)
if(nxtnd[i]==-1&&d[i]>d[R])R=i;
for(int j=1;j<21;++j)
if(1<<j<=n)
for(int i=1;i<=n;++i)
fa[i][j]=fa[fa[i][j-1]][j-1];else break;
sum<<=1;
for(int i=1;i<=n;++i){
if(nxtnd[i]==-1){
ll ans=sum-max(Dis(i,L),Dis(i,R));
printf("%lld\n",ans);
}else{
ll ans=sum+Dis(i,nxtnd[i])-max(Dis(nxtnd[i],L),Dis(nxtnd[i],R));
printf("%lld\n",ans);
}
}
return 0;
}