以下所说的“关键点”指出租车必经的点。
先随便从一个关键点开始 d f s dfs dfs,找虚树上的直径。如果一个点的子树中有关键点,就把这个点也标为关键点。因为要到达它子树的关键点就必须经过该点。
如果出发点在关键点,答案就是虚树上所有边的长度乘2减去它到虚树的一个点上的最长距离。而这个点必为直径两个端点之一。可以这样想:选中虚树上的一条链,从这条链的一端出发走向另一端,对于每一个结点,要遍历完它的子树再回到这个结点。走完之后,链上所有结点的子树都走了两遍,这条链只走了一遍。那么让这条链最长就能使答案最小。
对于不在关键点上的点,只需要加上它到离它最近的一个关键点的距离。因为这个点不是关键点,所以它的子树中不存在关键点。那么出租车就必须要往父亲方向走到关键点上。
所以思路大致分为几步:
一、标记出所有关键点,找到直径一端。
二、
d
f
s
dfs
dfs出所有点到当前root的距离,同时找到直径另一端。
三、得到每个点到虚树的最远距离。得到每个非关键点到离它最近的关键点的距离。
四、答案就是
(
s
u
m
∗
2
−
d
i
s
[
i
]
[
0
]
+
d
i
s
[
i
]
[
1
]
∗
2
)
(sum*2-dis[i][0]+dis[i][1]*2)
(sum∗2−dis[i][0]+dis[i][1]∗2),其中
s
u
m
sum
sum是虚树边权和,
d
i
s
[
i
]
[
0
]
dis[i][0]
dis[i][0]是该点到虚树最远距离,
d
i
s
[
i
]
[
1
]
∗
2
dis[i][1]*2
dis[i][1]∗2是它到离它最近的关键点的距离。
由于非关键点中
d
i
s
[
i
]
[
0
]
dis[i][0]
dis[i][0]多减了一次
d
i
s
[
i
]
[
1
]
dis[i][1]
dis[i][1],所以要乘2。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=5e5+10;
int n,k,u,v,key[maxn],siz[maxn],root;
int Head[maxn],Next[maxn<<1],V[maxn<<1],cnt=0;
ll W[maxn<<1],dis[maxn][2],w,sum=0;
inline void add(int u,int v,ll w){Next[++cnt]=Head[u],V[cnt]=v,W[cnt]=w,Head[u]=cnt;}
inline int read(){
int x=0;char ch=getchar();
while(!isdigit(ch)) ch=getchar();
while(isdigit(ch)) x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x;
}
inline void print(ll x){
if(x>9) print(x/10);
putchar(x%10+'0');
}
inline void dfs1(int u,int fa){
siz[u]=key[u];
for(int i=Head[u];i;i=Next[i]) if(V[i]!=fa){
dis[V[i]][0]=dis[u][0]+W[i];
if(key[V[i]]&&dis[V[i]][0]>dis[root][0]) root=V[i];
dfs1(V[i],u),siz[u]+=siz[V[i]];
if(siz[V[i]]) sum+=W[i];
}key[u]=(siz[u]!=0);
}
inline void dfs2(int u,int fa,int id){
for(int i=Head[u];i;i=Next[i]) if(V[i]!=fa){
dis[V[i]][id]=dis[u][id]+W[i];
if(key[V[i]]&&dis[V[i]][id]>dis[root][id]) root=V[i];
dfs2(V[i],u,id);
}
}
inline void dfs3(int u,int fa){
for(int i=Head[u];i;i=Next[i]) if(V[i]!=fa)
dis[V[i]][1]=(key[V[i]])?(0):(dis[u][1]+W[i]),dfs3(V[i],u);
}
int main(){
n=read(),k=read();
for(int i=1;i<n;++i) u=read(),v=read(),w=(ll)read(),add(u,v,w),add(v,u,w);
for(int i=1;i<=k;++i) key[root=read()]=1;
dfs1(root,0);for(int i=1;i<=n;++i) dis[i][0]=0;
dfs2(root,0,0),dfs2(root,0,1);
for(int i=1;i<=n;++i) dis[i][0]=max(dis[i][0],dis[i][1]),dis[i][1]=0;
dfs3(root,0);
for(int i=1;i<=n;++i) print(sum*2-dis[i][0]+dis[i][1]*2),putchar(10);
}
调试发现的错误:
inline void dfs1(int u,int fa){
siz[u]=key[u];
for(int i=Head[u];i;i=Next[i]) if(V[i]!=fa){
dis[V[i]][0]=dis[u][0]+W[i];
if(key[V[i]]&&dis[V[i]][0]>dis[root][0]) root=V[i];
dfs1(V[i],u),siz[u]+=siz[V[i]];
if(siz[V[i]]) sum+=W[i];
}key[u]=(siz[u]!=0),dis[u][0]=0;
}
最开始写的时候,我把
d
i
s
[
u
]
[
0
]
dis[u][0]
dis[u][0]在
d
f
s
dfs
dfs里面清空了。这样是不对的。
想了一下,发现如果这个点是直径的一端
r
o
o
t
root
root,那么把他的子树搜完之后,
d
i
s
[
r
o
o
t
]
[
0
]
dis[root][0]
dis[root][0]就变成0。那么这个原本正确的
r
o
o
t
root
root就会被后面搜到的点替换掉,答案就不对了。