题目链接
http://acm.hust.edu.cn/vjudge/problem/32271
思路
树形dp,树形dp一般要用一维来表示当前节点/父节点的访问状态
状态表示: d[u][i][s] 当前在第i个节点,访问了i个节点,并且是否回到当前节点(s == 0 or 1)所走过的最少距离
分析:分为回到当前节点或不会回到当前节点,设u为当前节点,v为u的一个儿子节点
1. 回到当前节点,s == 1:
若回到当前节点,那么必然访问子树v回到当前节点,以及访问其他子树回到当前节点。
转移方程:d[u][i][1] = min(d[u][i][1], d[u][i - k][1] + d[v][k][1] + 2 * w(u, v))
2. 不回到当前节点
若最后没有回到当前节点,包括:
I. 从其他节点回到u,在从u出发到v。
转移方程:d[u][i][0] = min(d[u][i][0], d[u][i - k][1] + d[v][k][0] + w(u, v))
II. 从v回到u,并从u出发访问其他节点
转移方程:d[u][i][0] = min(d[u][i][0], d[u][i - k][0] + d[v][k][1] + 2 * w(u, v))
细节
- 在代码56行枚举i时,由转移方程可以看出第一维都是u,第二维的i是向前依赖的,因此每次更新d[u][i][s]的时候都会覆盖掉之前(同0-1背包),相当于在同一行上滚动,因此应该从后向前枚举
- 因为每个节点的转移只到其儿子节点上,因此可将该树的边建成有向边。
代码
#include <iostream>
#include <cstring>
#include <stack>
#include <vector>
#include <set>
#include <map>
#include <cmath>
#include <queue>
#include <sstream>
#include <iomanip>
#include <fstream>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <deque>
#include <bitset>
#include <algorithm>
using namespace std;
#define PI acos(-1.0)
#define LL long long
#define PII pair<int, int>
#define PLL pair<LL, LL>
#define mp make_pair
#define IN freopen("in.txt", "r", stdin)
#define OUT freopen("out.txt", "wb", stdout)
#define scan(x) scanf("%d", &x)
#define scan2(x, y) scanf("%d%d", &x, &y)
#define scan3(x, y, z) scanf("%d%d%d", &x, &y, &z)
#define sqr(x) (x) * (x)
#define pr(x) cout << #x << " = " << x << endl
#define lc o << 1
#define rc o << 1 | 1
#define pl() cout << endl
#define CLR(a, x) memset(a, x, sizeof(a))
#define FILL(a, n, x) for (int i = 0; i < n; i++) a[i] = x
const int maxn = 505;
const int INF = 0x3e3e3e3e;
vector<int> G[maxn];
int n, sum, a[maxn][maxn], d[maxn][maxn][2], son[maxn];
void init() {
for (int i = 0; i < maxn; i++) G[i].clear();
}
void dfs(int u) {
for (int i = 1; i <= n; i++) d[u][i][0] = d[u][i][1] = INF;
d[u][1][0] = d[u][1][1] = 0;
son[u] = 1;
for (int j = 0; j < G[u].size(); j++) {
int v = G[u][j];
int len = a[u][v];
dfs(v);
son[u] += son[v];
for (int i = son[u]; i >= 1; i--) {
for (int k = 1; k <= son[v] && k < i; k++) {
d[u][i][0] = min(d[u][i][0], min(d[u][i - k][1] + d[v][k][0] + len, d[u][i - k][0] + d[v][k][1] + 2 * len));
d[u][i][1] = min(d[u][i][1], d[u][i - k][1] + d[v][k][1] + 2 * len);
}
}
}
}
int main() {
int kase = 0;
while (scan(n) && n) {
init();
for (int i = 0; i < n - 1; i++) {
int x, y, dist;
scan3(x, y, dist);
G[y].push_back(x);
a[x][y] = a[y][x] = dist;
}
dfs(0);
int Q, x;
scan(Q);
printf("Case %d:\n", ++kase);
while (Q--) {
scan(x);
for (int k = n; k >= 1; k--) {
if (x >= d[0][k][0]) {
printf("%d\n", k);
break;
}
}
}
}
return 0;
}