原题链接:https://ac.nowcoder.com/acm/contest/10272/M
题意
有n个节点,每次选择一个节点并消除,所消耗的能量是该点的权值加该儿子节点的权值和,问你有[1,m]次超能力,每次超能力可以直接消去一个节点,问使用[1,m]次超能力并消除所有节点所花费的最小能量是多少。
分析
显然贪心是无法完成的,考虑到最值、使用次数等状态问题,首先想到的就是树形dp。
我们设f[0/1][x][j]代表是否删除x节点,并且删完后子树大小还剩j的最小能量消耗。接着就是怎么推状态的转移,首先考虑子节点对当前节点的贡献
- 如果当前节点被删除,那么子节点无论如何不会对当前节点产生贡献
- 如果当前节点没被删除,当子节点也没被删除时,那么还要加上子节点的权值,否则不用加
那么最基础的转移方程就可以写出来了
dp[1][x][j] = min(dp[1][x][j], dp[1][x][j-k] + dp[1][v][k]);
dp[1][x][j] = min(dp[1][x][j], dp[1][x][j-k] + dp[0][v][k]);
dp[0][x][j] = min(dp[0][x][j], dp[0][x][j-k] + dp[1][v][k]);
dp[0][x][j] = min(dp[0][x][j], dp[0][x][j-k] + dp[0][v][k] + val[v]);
如果直接这样写,一定会tle,因为这样的复杂度是O(n^3)的
考虑怎么优化,先简化一下式子
dp[1][x][j] = min(dp[1][x][j], dp[1][x][j-k] + min(dp[1][v][k], dp[0][v][k]));
dp[0][x][j] = min(dp[0][x][j], dp[0][x][j-k] + min(dp[1][v][k], dp[0][v][k] + val[v]));
发现每次在枚举总子树个数时,可以只枚举除当前子树外的,然后去加上当前子树的贡献,这样最终被证明复杂度是O(n^2)的。
dp[1][x][j+k] = min(dp[1][x][j+k], dp[1][x][j] + min(dp[1][v][k], dp[0][v][k]));
dp[0][x][j+k] = min(dp[0][x][j+k], dp[0][x][j] + min(dp[1][v][k], dp[0][v][k] + val[v]));
Code
#include <bits/stdc++.h>
using namespace std;
//#define ACM_LOCAL
#define re register
#define fi first
#define se second
#define please_AC return 0
const int N = 1e6 + 10;
const int M = 1e6 + 10;
const int INF = 1e9;
const double eps = 1e-4;
const int MOD = 1e9+7;
typedef long long ll;
vector<int> g[N];
ll dp[2][2005][2005], val[N];
int dfs(int x, int fa) {
int sum = 1, t;
for (auto v : g[x]) {
if (v == fa) continue;
t = dfs(v, x);
for (int j = sum; j >= 0; j--) {
for (int k = 0; k <= t; k++) {
dp[1][x][j+k] = min(dp[1][x][j+k], dp[1][x][j] + min(dp[1][v][k], dp[0][v][k]));
dp[0][x][j+k] = min(dp[0][x][j+k], dp[0][x][j] + min(dp[1][v][k], dp[0][v][k] + val[v]));
}
}
sum += t;
}
return sum;
}
void solve() {
int T; scanf("%d", &T); while (T--) {
int n; scanf("%d", &n);
for (int i = 0; i <= n; i++) {
g[i].clear();
for (int j = 0; j <= n; j++) {
dp[0][i][j] = dp[1][i][j] = 1e18;
}
}
for (int i = 2; i <= n; i++) {
int x; scanf("%d", &x);
g[i].push_back(x);
g[x].push_back(i);
}
for (int i = 1; i <= n; i++) scanf("%lld", &val[i]), dp[0][i][1] = val[i], dp[1][i][0] = 0;
dfs(1, 0);
for (int i = n; i >= 0; i--) {
printf("%lld ", min(dp[0][1][i], dp[1][1][i]));
}
printf("\n");
}
}
signed main() {
#ifdef ACM_LOCAL
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
#endif
solve();
return 0;
}