前置知识:
AcWing 10 树上背包——有依赖的背包问题
#include<bits/stdc++.h>
using namespace std;
const int N = 110;
int n,m,root,dp[N][N];
int idx, ne[N], e[N], h[N], w[N], v[N];
void add(int a, int b){ e[idx] = b, ne[idx] = h[a], h[a] = idx ++;}
void dfs(int u){
for(int i = h[u]; ~i; i = ne[i]){
int son = e[i];
dfs(son);
for(int j = m - v[u]; j; j --)
for(int k = 0; k <= j; ++ k)
dp[u][j] = max(dp[u][j], dp[u][j - k] + dp[son][k]);
}
for(int i = m; i >= v[u]; -- i) dp[u][i] = dp[u][i - v[u]] + w[u];
for(int i = 0; i < v[u]; ++ i) dp[u][i] = 0;
}
int main(){
cin >> n >> m;
memset(h, -1, sizeof h);
for(int i = 1; i <= n; ++ i){
int p;
cin >> v[i] >> w[i] >> p;
if(p == -1) root = i;
else add(p, i);
}
dfs(root);
cout << dp[root][m];
return 0;
}
本题tips:将题目中的统计答案变形,变成统计每一条边的边权乘上两边的黑点数量与边权乘上两边的白点数量的值的和【可以理解为:每一对在这条边两侧的不同颜色的节点, 统计答案时都会经过这条边一次。 所以总共经过的次数是两侧黑点数量积 + 两侧白点数量积】, 转化为树上背包问题。
关于循环顺序的问题:
1、对于第一层循环j从大到小, 因为省掉了分组背包中“分组”的一维, 所以只能倒叙枚举, 这是背包问题;
2、对于第二层枚举子树中黑色点的数量k, 看似必须正序枚举。在普通的树上背包中,这一层枚举的顺序无所谓, 因为k = 0表示不选, 不会对答案有什么影响; 而这个题就算子树中没有黑点, 白色点依然会对答案产生影响, 所以k = 0应该先计算, 所以要正序枚举的原因是正序枚举k是从0开始的,而这道题的状态转移必须要先将k=0的状态转移过来才能成立。
一个更好理解的解释:
“但是这道题比较特殊,就是我们的k可以等于0,这就导致对 于每一个j,最后一个k一定会进行一次非法转移。通俗点讲, 最后一个转移是:f[u][j]=max(f[u][j],f[u][j]+f[v][0]+val); 这转移肯定会发生,并且我们用的来源状态f[u][j-k]由于k=0的 原因,已经不满足我们原本要求的“我们需要的原状态不会被在这 之前更新”了,因为f[u][j]已经不知道被更新多少次了。”
所以如果先处理k = 0的情况, 其实倒序枚举也无所谓。
关于为什么这个题做树上背包需要先初始化-1的问题:——去除不合法的情况:
我们枚举在当前的子节点的子树中,我们枚举它里面有k个黑点,那么我们需要一个在其他子树中选了共 j - k 个黑点的状态,但是如果其他的子树的大小总和还不到 j - k 的话,那么这个状态显然是不合法的,所以我们要去除这种情况。
一般的树上背包中, 不存在这种情况, 因为最大体积是m, 不见的非要全部用上, 所以不需要去除不合法的情况。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 4010;
int idx, ne[N], e[N], h[N];
int n,k1,sz[N];
ll dp[N][N], w[N];
void add(int a, int b, ll c){ e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++;}
void dfs(int u, int fa){
sz[u] = 1;
dp[u][0] = dp[u][1] = 0;
for(int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if(j == fa) continue;
dfs(j, u);
sz[u] += sz[j];
}
for(int i = h[u]; ~i; i = ne[i]){
int son = e[i];
if(son == fa) continue;
for(int j = min(sz[u], k1); j >= 0; j --){
for(int k = 0; k <= min(j, sz[son]); ++ k){
if(dp[u][j - k] >= 0){
ll val = (ll)k * (k1 - k) * w[i] + (ll)(sz[son] - k) * (n - k1 - sz[son] + k) * w[i];
dp[u][j] = max(dp[u][j], dp[u][j - k] + dp[son][k] + val);
}
}
}
}
}
int main(){
cin >> n >> k1;
memset(h, -1, sizeof h);
memset(dp, -1, sizeof dp);
for(int i = 1; i < n; ++ i){
int a, b; ll c;
scanf("%d%d%lld",&a,&b,&c);
add(a,b,c), add(b,a,c);
}
dfs(1,1);
cout << dp[1][k1];
return 0;
}