题目大意:
解题思路:
- 首先我们分析一下每个点状态是怎么?
1.1: 对于这个点我们删除的代价我们要看一下它儿子有多少个没被删除(指没用)因为父亲节点没删这个点呀删除不了
1.2: 那么这个点下面用了多少次魔法。
1.3: 这个点是否使用用魔法去删除
那么
d
p
dp
dp方程就很显然了
d
p
[
i
]
[
j
]
[
0
/
1
]
dp[i][j][0/1]
dp[i][j][0/1]表示第i个点为子树里面用了j次魔法并且第i个点是否用魔法去删除并且删除完整个子树的最小代价是多少?
那么
d
p
dp
dp转移方程就有了
for(int i = 0;i <= siz[x] + siz[v];i++) {// 临时数组
tmp[i][0] = tmp[i][1] = INF;
}
for(int i = 0;i <= siz[x];i++) {
for(int j = 0;j <= siz[v];j++) {
tmp[i + j][0] = min(tmp[i + j][0],dp[x][i][0] + dp[v][j][0] + hp[v]);
// 在前面子树里面用了i次魔法并且x点不用魔法删除,并且v点子树里面使用了j次魔法,并且v点不使用魔法删除,因为v点不使用魔法删除那么v肯定是在x点之后删除的那么要加上这个v点的贡献
if(j > 0) {
tmp[i + j][0] = min(tmp[i + j][0],dp[x][i][0] + dp[v][j][1]);
}
if(i > 0) {
tmp[i + j][1] = min(tmp[i + j][1],dp[x][i][1] + dp[v][j][0]);
}
if(i > 0 && j > 0) {
tmp[i + j][1] = min(tmp[i + j][1],dp[x][i][1] + dp[v][j][1]);
}
}
}
for(int j = 0;j <= siz[x] + siz[v];j++) {
dp[x][j][0] = tmp[j][0];
dp[x][j][1] = tmp[j][1];// 复制回去
}
注意更新方式,每次只能去遍历新加的节点这样才能保证复杂度是 O ( n 2 ) O(n^2) O(n2)
AC code
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int maxn = 2e3 + 7;
typedef long long ll;
const ll INF = 1e18;
vector<int>G[maxn];
ll dp[maxn][maxn][2],tmp[maxn][2];
ll hp[maxn];
int siz[maxn];
void dfs(int x) {
siz[x] = 1;
for(int t = 0;t < G[x].size();t++) {
int v = G[x][t];
dfs(v);
for(int i = 0;i <= siz[x] + siz[v];i++) {
tmp[i][0] = tmp[i][1] = INF;
}
for(int i = 0;i <= siz[x];i++) {
for(int j = 0;j <= siz[v];j++) {
tmp[i + j][0] = min(tmp[i + j][0],dp[x][i][0] + dp[v][j][0] + hp[v]);
if(j > 0) {
tmp[i + j][0] = min(tmp[i + j][0],dp[x][i][0] + dp[v][j][1]);
}
if(i > 0) {
tmp[i + j][1] = min(tmp[i + j][1],dp[x][i][1] + dp[v][j][0]);
}
if(i > 0 && j > 0) {
tmp[i + j][1] = min(tmp[i + j][1],dp[x][i][1] + dp[v][j][1]);
}
}
}
for(int j = 0;j <= siz[x] + siz[v];j++) {
dp[x][j][0] = tmp[j][0];
dp[x][j][1] = tmp[j][1];
}
siz[x] += siz[v];
}
for(int i = 0;i <= siz[x];i++) {
dp[x][i][0] += hp[x];
}
}
int main() {
int T;scanf("%d",&T);
while(T--) {
int n;scanf("%d",&n);
for(int i = 1;i <= n;i++) {
for(int j = 0;j <= n;j++) {
dp[i][j][0] = dp[i][j][1] = 0;
}
}
for(int i = 1;i <= n;i++) G[i].clear();
for(int i = 2;i <= n;i++) {
int f;scanf("%d",&f);
G[f].push_back(i);
}
for(int i = 1;i <= n;i++) {
scanf("%lld",&hp[i]);
}
dfs(1);
for(int i = 0;i <= n;i++) {
printf("%lld ",min(dp[1][i][0],dp[1][i][1]));
}
printf("\n");
}
return 0;
}