原题链接:http://acm.hdu.edu.cn/showproblem.php?pid=4276
题意
有一棵树,每个节点都有财富wi,每条边都会花费ci的时间,问你有T的时间,从节点1出发到节点n出口最多可以拿走多少宝物。
分析
非常明显的树形背包的模型,列出状态 f [ i ] [ j ] f[i][j] f[i][j]代表第i个节点及其子树花费j的时间最多取到的财富总和,那么状态转移的方程可以轻松写出 f [ u ] [ j ] = m a x ( f [ u ] [ j ] , f [ j − k − c o s t ] + f [ v ] [ k ] ) f[u][j]=max(f[u][j], f[j-k-cost]+f[v][k]) f[u][j]=max(f[u][j],f[j−k−cost]+f[v][k])
接着就是分类考虑当前取的边是否在最短路径上,因为根据树的性质,两点之间的路径是唯一的,因此我们一定会取一遍最短路径上的财富,我们直接跑一遍最短路将最短路径上的边权改为0。如果当前的边不在最短路径上,那么cost要乘2,因为肯定要走两遍当前边。
Code
#include <bits/stdc++.h>
using namespace std;
//#define ACM_LOCAL
#define re register
#define fi first
#define se second
#define please_AC return 0
const int N = 2000 + 10;
const int M = 5e6 + 10;
const int INF = 1e9;
const double eps = 1e-4;
const int MOD = 1e9 + 7;
typedef long long ll;
typedef unsigned long long ull;
int n, m, cnt;
int dp[105][505], h[N], val[N], dis[N], vis[N], pre[N], rnk[N];
struct Edge {
int to, next, w;
}e[M];
void add(int u, int v, int w) {
e[cnt].to = v;
e[cnt].w = w;
e[cnt].next = h[u];
h[u] = cnt++;
}
struct node {
int d, now;
bool operator < (const node &rhs) const {
return d > rhs.d;
}
};
void dij(int st) {
priority_queue<node> q;
memset(dis, 0x3f, sizeof dis);
memset(vis, 0, sizeof vis);
memset(pre, 0, sizeof pre);
dis[st] = 0; q.push({0, st});
while (q.size()) {
int now = q.top().now;
q.pop();
if (vis[now]) continue;
vis[now] = 1;
for (int i = h[now]; ~i; i = e[i].next) {
int v = e[i].to;
if (dis[v] > dis[now] + e[i].w) {
dis[v] = dis[now] + e[i].w;
pre[v] = now;
rnk[v] = i;
if (!vis[v]) q.push({dis[v], v});
}
}
}
int now = n;
while(now != 1) {
e[rnk[now]].w = e[rnk[now] ^ 1].w = 0;
now = pre[now];
}
}
void dfs(int x, int fa) {
for (int i = 0; i <= m; i++) dp[x][i] = val[x];
for (int i = h[x]; ~i; i = e[i].next) {
int v = e[i].to;
if (v == fa) continue;
dfs(v, x);
for (int j = m; j >= e[i].w*2; j--) {
for (int k = 0; k + e[i].w * 2 <= j; k++) {
dp[x][j] = max(dp[x][j], dp[x][j-k-2*e[i].w] + dp[v][k]);
}
}
}
}
void solve() {
while(scanf("%d%d", &n, &m) != EOF) {
memset(h, -1, sizeof h); cnt = 0;
memset(dp, 0, sizeof dp);
for (int i = 1; i <= n-1; i++) {
int u, v, w; cin >> u >> v >> w;
add(u, v, w), add(v, u, w);
}
for (int i = 1; i <= n; i++) cin >> val[i];
dij(1);
if (dis[n] > m) {
printf("Human beings die in pursuit of wealth, and birds die in pursuit of food!\n");
continue;
}
m -= dis[n];
dfs(1, 0);
printf("%d\n", dp[1][m]);
}
}
signed main() {
#ifdef ACM_LOCAL
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
#endif
solve();
please_AC;
}