题目:
题目链接:https://jzoj.net/senior/#main/show/3919
给出一棵树以及若干个标记点,树有边权,求分别从每一个点出发经过所有的标记点所需的最小边权和。
思路:
首先我们设
T
T
T为能包含所有标记点的最小的树,
s
i
z
e
T
sizeT
sizeT为
T
T
T的边权和。
考虑所有
x
∈
T
x\in T
x∈T的点
x
x
x,如果我们要求最终需要回到出发点
(
x
)
(x)
(x),那么显然答案就是
2
×
s
i
z
e
T
2\times sizeT
2×sizeT。因为
T
T
T的每一条边都要正好经过两遍。
但如果一个点
x
∉
T
x∉T
x∈/T,那么答案就是
2
(
s
i
z
e
T
+
d
i
s
)
2(sizeT+dis)
2(sizeT+dis)。其中
d
i
s
dis
dis表示
x
x
x与
T
T
T的距离。原因也是每一条边要经过两次。
但是我们到达最后一个标记点后可以不用回到出发点。所以为了使边权和最小,肯定最后去距离
x
x
x最远的点,这样
x
x
x与这个点的路径就只要走一次,而这个路径肯定能省就省。
那么可以用一棵线段树储存每一个标记点到现在
d
f
s
dfs
dfs到的点的距离。用
所以第一遍
d
f
s
dfs
dfs求出每一个特殊点的
i
d
,
r
k
id,rk
id,rk分别表示这是第几个遍历到的标记点,以及第
r
k
rk
rk个遍历到的标记点是哪个。同时求出每一个点为根的子树内有多少个标记点,记为
s
i
z
e
size
size
第二遍
d
f
s
dfs
dfs就算出每一个点到
T
T
T的距离
d
i
s
dis
dis,以及
s
i
z
e
T
sizeT
sizeT
第三遍
d
f
s
dfs
dfs就换根,同时维护线段树。
时间复杂度
O
(
n
log
n
)
O(n\log n)
O(nlogn)
代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N=500010;
int n,m,tot,root,head[N],id[N],rk[N],size[N];
ll sum,ans[N],dis[N],dep[N];
bool flag[N];
struct edge
{
int next,to,dis;
}e[N*2];
struct Treenode
{
int l,r;
ll maxn,lazy;
};
struct Tree
{
Treenode tree[N*4];
void pushup(int x)
{
tree[x].maxn=max(tree[x*2].maxn,tree[x*2+1].maxn);
}
void pushdown(int x)
{
if (tree[x].lazy)
{
tree[x*2].lazy+=tree[x].lazy;
tree[x*2+1].lazy+=tree[x].lazy;
tree[x*2].maxn+=tree[x].lazy;
tree[x*2+1].maxn+=tree[x].lazy;
tree[x].lazy=0;
}
}
void build(int x,int l,int r)
{
tree[x].l=l; tree[x].r=r;
if (l==r)
{
tree[x].maxn=dep[rk[l]];
return;
}
int mid=(l+r)>>1;
build(x*2,l,mid); build(x*2+1,mid+1,r);
pushup(x);
}
void update(int x,int l,int r,ll val)
{
if (l>r) return;
if (tree[x].l==l && tree[x].r==r)
{
tree[x].maxn+=val;
tree[x].lazy+=val;
return;
}
pushdown(x);
int mid=(tree[x].l+tree[x].r)>>1;
if (r<=mid) update(x*2,l,r,val);
else if (l>mid) update(x*2+1,l,r,val);
else update(x*2,l,mid,val),update(x*2+1,mid+1,r,val);
pushup(x);
}
}Tree;
void add(int from,int to,int dis)
{
e[++tot].to=to;
e[tot].dis=dis;
e[tot].next=head[from];
head[from]=tot;
}
int dfs1(int x,int fa)
{
if (flag[x]) id[x]=++tot,rk[tot]=x,size[x]=1;
else id[x]=tot+1;
for (int i=head[x];~i;i=e[i].next)
if (e[i].to!=fa)
{
dep[e[i].to]=dep[x]+e[i].dis;
size[x]+=dfs1(e[i].to,x);
}
return size[x];
}
void dfs2(int x,int fa)
{
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa)
{
if (!size[v]) dis[v]=dis[x]+e[i].dis;
else sum+=e[i].dis;
dfs2(v,x);
}
}
}
void dfs3(int x,int fa)
{
ans[x]=sum*2+dis[x]*2-Tree.tree[1].maxn;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa)
{
Tree.update(1,id[v],id[v]+size[v]-1,-e[i].dis*2);
Tree.update(1,1,m,e[i].dis);
dfs3(v,x);
Tree.update(1,1,m,-e[i].dis);
Tree.update(1,id[v],id[v]+size[v]-1,e[i].dis*2);
}
}
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&m);
for (int i=1,x,y,z;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);
add(x,y,z); add(y,x,z);
}
for (int i=1,x;i<=m;i++)
{
scanf("%d",&x);
flag[x]=1;
root=x;
}
tot=0;
dfs1(root,0);
Tree.build(1,1,m);
dfs2(root,0);
dfs3(root,0);
for (int i=1;i<=n;i++)
printf("%lld\n",ans[i]);
return 0;
}