题面
给定一颗有 n n n个节点的无根树,每一条边 e e e有一个经过的时间 t ( e ) t(e) t(e),树上有 K K K个关键节点,现在对于每一个节点 u u u,你需要回答下面的问题:
从 u u u出发,开一辆车,要求到达所有关键节点,最终不必回到节点 u u u,问完成的最短时间。
数据范围
1 ≤ k ≤ n ≤ 5 × 1 0 5 ∀ e ∈ E , 1 ≤ t ( e ) ≤ 1 0 9 1 \leq k \leq n \leq 5 \times 10^5\\ \forall e\in E,1 \leq t(e) \leq 10^9 1≤k≤n≤5×105∀e∈E,1≤t(e)≤109
题解
容易确定思考方向:换根dp。
有一个比较好用的套路,对于一个物体在树上移动,要求到达树上的某些点,可以采用规定这个物体是否回到了根节点的方法。当然了,还有另一种相近的情况是:一个点走到某一个位置后消失了,但是根节点处又会立刻出现新的物体。这样整棵树上上依然只有一个点,也可以思考一下使用类似的方法。
回到本题,可以发现因为只有一个物体,那么就只能不会到根节点最多一次(停在了最后一个关键节点),而且我们很容易就能得到最后停在了哪一个关键节点,只需要找到从根节点 u u u处最远的关键节点即可。
设
f
(
u
)
f(u)
f(u)表示从节点
u
u
u处出发,遍历了所有子树
T
(
u
)
T(u)
T(u)中的关键节点之后,物体最终回到了节点
u
u
u的时间,那么可以得到:
f
(
u
)
=
∑
v
∈
s
o
n
(
u
)
(
f
(
v
)
+
t
(
u
,
v
)
×
2
)
×
[
T
h
e
r
e
a
r
e
k
e
y
n
o
d
e
s
i
n
T
(
v
)
]
f(u)=\sum\limits_{v \in \mathrm{son}(u)} (f(v)+t(u,v) \times 2)\times [\mathrm{There~are~key~nodes ~ in~T(v)}]
f(u)=v∈son(u)∑(f(v)+t(u,v)×2)×[There are key nodes in T(v)]
特别的,假如
T
(
u
)
T(u)
T(u)中没有关键节点,那么
f
(
u
)
=
0
f(u)=0
f(u)=0
设 m x d ( u ) , s m d ( u ) mxd(u), smd(u) mxd(u),smd(u)分别表示从节点 u u u出发,到达的最远的和次远的 T ( u ) T(u) T(u)中的关键节点的距离。
接下来是从 u u u到 f a ( u ) fa(u) fa(u)的部分。
按照套路,设 g ( u ) g(u) g(u)表示从 u u u到 T ( u ) T(u) T(u)以外的所有关键节点,最终回到 u u u的时间,显然包含两个部分,一部分是 f a ( u ) fa(u) fa(u)继续往上走,一部分是到达 f a ( u ) fa(u) fa(u)之后折下来,也就是到达 u u u的兄弟节点的子树中的所有关键节点。
这里因为
v
v
v对
f
(
u
)
f(u)
f(u)的贡献方式是没有两两兄弟相关的加法(如果有两两相关的话就需要分别重新计算
u
u
u前后两部分的兄弟的贡献),所以可以想到
g
(
u
)
g(u)
g(u)的转移方式:
g
(
u
)
=
g
(
f
a
(
u
)
)
+
f
(
f
a
(
u
)
)
−
f
(
u
)
+
t
(
f
a
(
u
)
,
u
)
×
2
×
[
T
h
e
r
e
i
s
n
o
t
k
e
y
n
o
d
e
i
n
T
(
u
)
]
g(u)=g(fa(u)) + f(fa(u))-f(u)+t(fa(u),u)\times 2\times [\mathrm{There\ is\ not\ key\ node\ in\ T(u)}]
g(u)=g(fa(u))+f(fa(u))−f(u)+t(fa(u),u)×2×[There is not key node in T(u)]
特别的,如果
T
(
u
)
T(u)
T(u)中没有关键节点,那么
f
(
f
a
(
u
)
)
f(fa(u))
f(fa(u))中显然没有计算到
f
(
u
)
f(u)
f(u),那么就只能手动加上
t
(
f
a
(
u
)
,
u
)
×
2
t(fa(u),u)\times 2
t(fa(u),u)×2
此外,我们还需要求出从
u
u
u出发,到
T
(
u
)
T(u)
T(u)之外的关键节点中最远的节点的距离,表示为
m
x
u
p
(
u
)
mxup(u)
mxup(u),显然我们可以得到:
m
x
u
p
(
u
)
=
{
max
{
m
x
d
(
f
a
(
u
)
)
,
m
x
u
p
(
f
a
(
u
)
)
}
+
t
(
f
a
(
u
)
,
u
)
m
x
d
(
f
a
(
u
)
)
≠
m
x
d
(
u
)
+
t
(
f
a
(
u
)
,
u
)
max
{
s
m
d
(
f
a
(
u
)
)
,
m
x
u
p
(
f
a
(
u
)
)
}
+
t
(
f
a
(
u
)
,
u
)
o
t
h
e
r
w
i
s
e
mxup(u)=\begin{cases} \max\{mxd(fa(u)), mxup(fa(u))\}+t(fa(u),u) & mxd(fa(u)) \ne mxd(u)+t(fa(u),u)\\ \max\{smd(fa(u)), mxup(fa(u))\}+t(fa(u),u) & otherwise \end{cases}
mxup(u)={max{mxd(fa(u)),mxup(fa(u))}+t(fa(u),u)max{smd(fa(u)),mxup(fa(u))}+t(fa(u),u)mxd(fa(u))=mxd(u)+t(fa(u),u)otherwise
那么对于节点
u
u
u,答案为
a
n
s
(
u
)
=
f
(
u
)
+
g
(
u
)
−
max
{
m
x
d
(
u
)
,
m
x
u
p
(
u
)
}
ans(u)=f(u)+g(u)-\max\{mxd(u),mxup(u)\}
ans(u)=f(u)+g(u)−max{mxd(u),mxup(u)}
#include <bits/stdc++.h>
#define LL long long
using namespace std;
const int maxn = 5e5 + 5;
struct Edge {
int v, nex; LL w;
Edge(int v = 0, LL w = 0, int nex = 0) : v(v), w(w), nex(nex) {}
} E[maxn << 1];
int n, m, isk[maxn], hsk[maxn], hd[maxn], tote;
void addedge(int u, int v, LL w) {
E[++tote] = Edge(v, w, hd[u]), hd[u] = tote;
E[++tote] = Edge(u, w, hd[v]), hd[v] = tote;
}
LL f[maxn], g[maxn], dis[maxn][3], ans[maxn];
void dfs1(int u, int fa) {
hsk[u] = isk[u];
for (int i = hd[u]; i; i = E[i].nex) {
int v = E[i].v;
if (v == fa) continue;
dfs1(v, u), hsk[u] += hsk[v];
if (!hsk[v]) continue;
f[u] += f[v] + E[i].w * 2;
if (dis[u][0] < dis[v][0] + E[i].w)
dis[u][1] = dis[u][0], dis[u][0] = dis[v][0] + E[i].w;
else if (dis[u][1] < dis[v][0] + E[i].w)
dis[u][1] = dis[v][0] + E[i].w;
}
}
void dfs2(int u, int fa) {
for (int i = hd[u]; i; i = E[i].nex) {
int v = E[i].v;
if (v == fa) continue;
if (m - hsk[v]) { //这里不能用continue直接跳过
g[v] = g[u] + f[u] - f[v];
if (!hsk[v]) g[v] += E[i].w * 2;
if (dis[v][0] + E[i].w == dis[u][0])
dis[v][2] = max(dis[u][1], dis[u][2]) + E[i].w;
else dis[v][2] = max(dis[u][0], dis[u][2]) + E[i].w;
}
dfs2(v, u);
}
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i < n; i++) {
int u, v; LL w; scanf("%d%d%lld", &u, &v, &w);
addedge(u, v, w);
}
for (int i = 1; i <= m; i++) {
int x; scanf("%d", &x), isk[x]++;
}
dfs1(1, 0), dfs2(1, 0);
for (int i = 1; i <= n; i++) {
ans[i] = f[i] + g[i] - max(dis[i][0], dis[i][2]);
printf("%lld\n", ans[i]);
}
return 0;
}