题意:
给定一棵树,从中取
3
3
3条不相关路径(没有节点被两条及以上的路径所覆盖),问经过的点权和的最大值为多少?
拓展:
如果是取
k
k
k 条不相关路径呢?
思路:
有一个
O
(
n
k
2
)
O(nk^2)
O(nk2)的树形DP的算法。
定义数组:
$dp[i][j] $ : 以
i
i
i为根的子树包含
j
j
j条不相交路径的最大值
那么所要求取的答案便是:
d
p
[
1
]
[
3
]
dp[1][3]
dp[1][3]
定义数组:
c
n
t
[
i
]
[
j
]
cnt[i][j]
cnt[i][j]:以
i
i
i为根的子树包含
j
j
j条不相交路径 +
1
1
1条以根为端点的不相交路径的最大值
此数组的作用是便于向上转移状态,通过保证有一条以根为端点的路径,则可以继续拓展连接到根的父亲,进而进行状态转移。
随后便是树形DP的过程,对于以
u
u
u为根的子树,如何合并其儿子们的状态呢?
对于某一个儿子节点
v
v
v,其可以贡献一条任意路径,也可以贡献两条任意路径,或三条任意路径,或一条任意路径 + 一个以根为端点的路径,或者两条任意路径 + 一个端点的路径。
故每个儿子存在很多种贡献情况,我们可以考虑再次DP,设:
s
u
m
[
k
]
[
i
]
[
j
]
sum[k][i][j]
sum[k][i][j]:前
k
k
k个儿子所构成的子树林包含
j
j
j条不相交路径和
i
i
i条以根为端点的不相交链的最大值
则根据以第
k
k
k个儿子为根的子树贡献路径的情况,
s
u
m
[
k
]
sum[k]
sum[k]的值可以由
s
u
m
[
k
−
1
]
sum[k-1]
sum[k−1]转移而来,故该数组可以使用滚动数组。
此题得解
代码:
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
typedef long long ll;
const int A = 5e5 + 10;
const int B = 5;
class Gra{
public:
int v,next;
}G[A<<1];
int head[A], val[A], tot, n;
ll dp[A][B];
ll cnt[A][B];
void Init(){
memset(head, -1, sizeof(head));
tot = 0;
}
void add(int u, int v){
G[tot].v = v;
G[tot].next = head[u];
head[u] = tot++;
}
void dfs(int u, int pre){
int twt = 0;
ll sum[2][B][B] = {0}, tmp[B][B] = {0};
for (int i = head[u]; i != -1; i = G[i].next) {
int v = G[i].v;
if (v == pre) continue;
dfs(v, u); twt++;
for (int i = 0; i <= 2; i++) {
for (int j = 0; j <= 3; j++) {
tmp[i][j] = 0;
sum[twt&1][i][j] = 0;
}
}
for (int j = 0; j <= 3; j++) {
tmp[0][j] = dp[v][j];
tmp[1][j] = cnt[v][j];
}
for (int x = 0; x <= 2; x++) {
for (int y = 0; x + y <= 2; y++) {
for (int p = 0; p <= 3; p++) {
for (int q = 0; p + q <= 3; q++) {
sum[twt&1][x + y][p + q] = max(sum[twt&1][x + y][p + q], sum[(twt&1)^1][x][p] + tmp[y][q]);
}
}
}
}
}
//update cnt[][]
for (int i = 0; i <= 1; i++) {
for (int j = 0; j <= 3; j++) {
cnt[u][j] = max(cnt[u][j], val[u] + sum[twt&1][i][j]);
}
}
// update dp[][]
for (int i = 0; i <= 2; i++) {
for (int j = 1; j <= 3; j++) {
dp[u][j] = max(dp[u][j], val[u] + sum[twt&1][i][j-1]);
}
}
for (int j = 0; j <= 3; j++) dp[u][j] = max(dp[u][j], sum[twt&1][0][j]);
}
int main(){
while(~scanf("%d", &n)){
Init();
for (int i = 1; i <= n; i++) {
scanf("%d", &val[i]);
}
for (int i = 1; i < n; i++) {
int u,v;
scanf("%d%d",&u, &v);
add(u, v);
add(v, u);
}
dfs(1, 1);
printf("%lld\n",dp[1][3]);
}
return 0;
}