首先我们可以求出这 $K$ 个点的一棵生成树(即原树中连通这 $K$ 个点的最小子树),记这棵生成树的边权和为 $sum$。
假设每次都要返回出发点 $x$,那么这里分两种情况讨论:
如果 $x$ 在树上,答案为 $2\times sum$;
否则,答案为 $2\times sum+2\times mindis(x,tree)$(先走到树上,遍历整棵树后再原路返回)。
那么如果不需要返回的话,答案在以上基础上要减去 $\max_{i\in tree} dis(x,i)$(最后停在离 $x$ 最远的树上点,省去返回的路程)。而这个最远点一定是这棵树直径的某个端点,所以先求出直径的两个端点,再从两端各做一次 BFS 就行了。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
const int N=500005;

// ---- Global state ---------------------------------------------------------------
// The tree is stored as a "chained forward star" adjacency list. cnt starts at 1, so
// the first edge id is 2; main inserts every undirected edge as two consecutive
// directed edges, hence e^1 is always the reverse of edge e. Ids 0 and 1 are never
// assigned, so list/key/nxt at those indices stay 0 (main relies on key[0] == 0 and
// list[1] == 0 to stop its walk at the root).
//
// NOTE: the successor array is called `nxt`, not `next`. With the file's
// `using namespace std;`, a global named `next` is ambiguous with std::next (C++11 and
// later) and every unqualified use of it fails to compile.
// NOTE(review): `list` would clash the same way with std::list if <list> were ever
// included; it is kept because main refers to it by that name.
int n,k;                    // number of vertices, number of key vertices
int S,T;                    // diameter endpoints of the marked subtree (set in main)
int cnt=1;                  // id of the most recently inserted edge
int pos[N];                 // pos[1..k]: the key vertices
int head[N];                // head[v]: first edge id leaving v (0 = none)
int from[N];                // from[v]: edge id through which BFS1 reached v (0 for the root)
int list[N<<1];             // list[e]: endpoint of edge e
int nxt[N<<1];              // nxt[e]: next edge leaving the same vertex (0 = end of chain)
int key[N<<1];              // key[e]: weight of edge e
long long ans;              // weight of the marked subtree (main doubles it before output)
long long dis[N];           // scratch distances for BFS2; after BFS4, distance to nearest marked vertex
long long dis1[N],dis2[N];  // distances from S and from T (filled by BFS3 in main)
bool vis[N];                // visited flags, cleared at the start of every BFS
bool in[N];                 // in[v]: v lies on the minimal subtree spanning all key vertices

// Reads one decimal integer from stdin, skipping any non-digit characters before it.
// A '-' anywhere in the skipped prefix makes the result negative.
// c is an int (not char) so EOF can be detected; on EOF before any digit this returns 0
// instead of spinning forever.
inline int read()
{
    int a=0,f=1;
    int c=getchar();
    while (c!=EOF&&(c<'0'||c>'9'))
    {
        if (c=='-') f=-1;
        c=getchar();
    }
    while (c>='0'&&c<='9')
    {
        a=a*10+(c-'0');
        c=getchar();
    }
    return a*f;
}

// Appends the directed edge x->y with weight z to the front of x's adjacency chain.
inline void insert(int x,int y,int z)
{
    nxt[++cnt]=head[x];
    head[x]=cnt;
    list[cnt]=y;
    key[cnt]=z;
}

// BFS over the whole graph from st. For each reached v, from[v] is the id of the edge
// used to reach it (from[st] = 0), so list[from[v]^1] is v's parent on the path to st.
inline void BFS1(int st)
{
    std::memset(vis,0,sizeof(vis));
    vis[st]=1;
    from[st]=0;
    std::queue<int> q;
    q.push(st);
    while (!q.empty())
    {
        int x=q.front(); q.pop();
        for (int i=head[x];i;i=nxt[i])
        {
            if (!vis[list[i]])
            {
                vis[list[i]]=1;
                from[list[i]]=i;
                q.push(list[i]);
            }
        }
    }
}

// Walks from st through marked vertices only (in[v] set). It fills dis[v] with the
// weighted distance from st and returns the farthest vertex found. Plain BFS gives
// exact weighted distances here only because the input is a tree (main reads n-1
// edges), so every vertex is reached along its unique path.
// mx starts at st, not 0. If st has no marked neighbour (e.g. k == 1), st itself is
// the farthest vertex; starting from 0 would return the nonexistent vertex 0 and break
// the diameter endpoints S/T computed from this result.
inline int BFS2(int st)
{
    std::memset(vis,0,sizeof(vis));
    vis[st]=1;
    dis[st]=0;
    int mx=st;
    std::queue<int> q;
    q.push(st);
    while (!q.empty())
    {
        int x=q.front(); q.pop();
        for (int i=head[x];i;i=nxt[i])
        {
            if (in[list[i]]&&!vis[list[i]])
            {
                vis[list[i]]=1;
                dis[list[i]]=dis[x]+key[i];
                if (dis[list[i]]>dis[mx]) mx=list[i];
                q.push(list[i]);
            }
        }
    }
    return mx;
}

// Weighted distances from st to every reachable vertex, written into out[].
// This is exact on a tree, for the same reason as BFS2.
inline void BFS3(int st,long long *out)
{
    std::memset(vis,0,sizeof(vis));
    vis[st]=1;
    out[st]=0;
    std::queue<int> q;
    q.push(st);
    while (!q.empty())
    {
        int x=q.front(); q.pop();
        for (int i=head[x];i;i=nxt[i])
        {
            if (!vis[list[i]])
            {
                vis[list[i]]=1;
                out[list[i]]=out[x]+key[i];
                q.push(list[i]);
            }
        }
    }
}

// Multi-source BFS seeded with every marked vertex (dis = 0 there). Afterwards, dis[v]
// is the distance from v to the marked set. This is exact because the graph is a tree
// and the marked vertices form a connected subtree, so each unmarked vertex has a
// unique path into it.
inline void BFS4()
{
    std::memset(vis,0,sizeof(vis));
    std::queue<int> q;
    for (int i=1;i<=n;i++)
    {
        if (in[i])
        {
            q.push(i);
            vis[i]=1;
            dis[i]=0;
        }
    }
    while (!q.empty())
    {
        int x=q.front(); q.pop();
        for (int i=head[x];i;i=nxt[i])
        {
            if (!vis[list[i]])
            {
                vis[list[i]]=1;
                dis[list[i]]=dis[x]+key[i];
                q.push(list[i]);
            }
        }
    }
}
int main()
{
n=read(); k=read();
for (int i=1;i<n;i++)
{
int u=read(),v=read(),w=read();
insert(u,v,w); insert(v,u,w);
}
for (int i=1;i<=k;i++) pos[i]=read();
BFS1(pos[1]);
for (int i=1;i<=k;i++)
for (int j=pos[i];!in[j]&&j;j=list[from[j]^1])
ans+=key[from[j]],in[j]=1;
ans<<=1;
S=BFS2(pos[1]); T=BFS2(S);
BFS3(S,dis1); BFS3(T,dis2);
BFS4();
for (int i=1;i<=n;i++)
if (in[i]) printf("%lld\n",ans-max(dis1[i],dis2[i]));
else printf("%lld\n",ans+(dis[i]<<1)-max(dis1[i],dis2[i]));
return 0;
}