题目
思路
树形DP
先注意到数据范围可知二维dp可行。
定义
d
p
[
v
]
[
j
]
dp[v][j]
dp[v][j]为使得以
v
v
v为端点的所有链长不超过
j
j
j并且树上的任意链长不超过
l
l
l的最小代价。
特别的,
d
p
[
v
]
[
0
]
dp[v][0]
dp[v][0]代表删除节点
v
v
v。
状态转移:
1)删除节点
v
v
v:
d
p
[
v
]
[
0
]
=
∑
d
p
[
u
]
[
j
]
∀
u
∈
S
o
n
v
dp[v][0] = \sum\ dp[u][j] \ \ \forall\ u\in Son_v
dp[v][0]=∑ dp[u][j] ∀ u∈Sonv
(
j
<
l
)
(j<l)
(j<l)
2)保留节点
v
v
v:
d
p
[
v
]
[
j
]
=
d
p
[
k
]
[
j
−
1
]
+
∑
d
p
[
u
]
[
m
i
n
(
l
−
j
−
1
,
j
−
1
)
)
]
dp[v][j]= dp[k][j-1]+\sum\ dp[u][min(l-j-1,j-1))]
dp[v][j]=dp[k][j−1]+∑ dp[u][min(l−j−1,j−1))],其中
k
∪
u
=
S
o
n
v
k\cup u=Son_v
k∪u=Sonv
如果要删除一个点,那么这棵树就不再连通,此时的问题就成了使得这个点每一棵子树的最长链都不大于 l l l。
如果保留这个点,我们知道一个树上最长的链可能是其中一棵子树的最长链与另一棵子树的最长链加上根节点自身所组成的。那么我们可以让其中一颗子树的最长链不大于 j − 1 j-1 j−1,而让其他所有子树的最长链为 l − j − 1 l-j-1 l−j−1即可,这样树上任意两条链之和都不会大于 l l l,注意有一个限制条件,到子节点的链长一定小于到当前根节点的链长,那么取 m i n ( l − j − 1 , j − 1 ) min(l-j-1, j-1) min(l−j−1,j−1)就好了。然后我们就可以枚举不同的 k k k(选择哪一棵子树使他的长度为 j − 1 j-1 j−1)得到最优解。
同时注意到在这个问题中根的位置不会对答案有影响,选取1为根就好了。
更新的时候需要有
d
p
[
v
]
[
j
]
=
m
i
n
(
d
p
[
v
]
[
j
]
,
d
p
[
v
]
[
j
−
1
]
)
dp[v][j] = min(dp[v][j], dp[v][j-1])
dp[v][j]=min(dp[v][j],dp[v][j−1])保证答案的完备性。(见
d
p
[
v
]
[
j
]
dp[v][j]
dp[v][j]的定义)。
那么根据上面就可得知对于删除节点时的操作,我们不需要取一个一个尝试子问题中
j
j
j的取值,直接取
l
−
1
l-1
l−1就可以,因为对于任何一个
j
j
j最优值已经包含在了
l
−
1
l-1
l−1里面。
同样的道题,最后我们得到答案时,也不需要枚举 d p [ 1 ] [ j ] dp[1][j] dp[1][j],直接输出 d p [ 1 ] [ l − 1 ] dp[1][l-1] dp[1][l−1]
#include <iostream>
#include <vector>
using namespace std;
#define pb push_back
const int maxn = 5000+10;
int a[maxn], n, l;
int dp[maxn][maxn];
vector<int> G[maxn];
void dfs(int v, int fa) {
dp[v][0] = a[v];
for ( auto u : G[v]) {
if (u == fa) continue;
dfs(u, v);
dp[v][0] += dp[u][l-1];
}
// 所有子树中任意两个链长和不能超过l
// 子树不能超过j-1,但是也不能大于j
for (int i = 1; i < l; i++) {
int tmp = 0;
for (auto u : G[v]) {
if (u == fa) continue;
//利用tmp,更新选择不同的k所得到的结果
dp[v][i] = min(tmp + dp[u][i-1], dp[v][i] + dp[u][min(i-1, l-i-1)]);
tmp += dp[u][min(i-1, l-i-1)];
}
//答案完备性
dp[v][i] = min(dp[v][i], dp[v][i-1]);
}
}
int main() {
scanf("%d%d", &n, &l);
for (int i = 1; i <= n; i++)
scanf("%d", &a[i]);
for (int i = 0; i < n-1; i++){
int a, b; scanf("%d%d", &a, &b);
G[a].pb(b); G[b].pb(a);
}
dfs(1, -1);
printf("%d", dp[1][l-1]);
return 0;
}