树形dp中,我们可以很容易利用题目给出的数据构建出一棵树,然后我们会用dp[i]记录,选择以 i 为根的子树中的点能得到的最优答案。而dp[i]的值可以根据,i节点的子节点u1,u2...un对应的dp[u1],dp[u2]...dp[un],递推得到。树形dp通常用搜索的形式实现,由于是树的结构,所以不包含重复的搜索,复杂度为O(n)。
我建树是建边用的是邻接表,做图论题的时候也是这样,感觉挺方便的。看了网上一些题解是用一个数组 a[i] 记录与 i 有同一个父节点的下一个节点。
下面是一些树形dp入门题
题目:POJ 1655 http://poj.org/problem?id=1655
题意:定义树上每个节点的有一个平衡值,平衡值为删除这个点后得到的树的大小的最大值。给出一棵树,求出平衡值最小的节点和它对应的平衡值。
思路:dp[ i ] 记录以 i 为根节点的树的大小。求每个节点的平衡值,要求以它的子节点为跟的子树的大小,和总节点数-以它为根节点的树的大小。
代码:
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <math.h>
#include <stdlib.h>
#define INF 0x7fffffff
#define MOD 1000000007
#include <vector>
using namespace std;
typedef long long ll;
vector<int> G[20005];
int d[20005], vis[20005];
void dp(int u)
{
d[u] = 1;
vis[u] = 1;
for(int i = 0; i < G[u].size(); i++)
{
int v = G[u][i];
if(vis[v])
{ //v是u的父节点而不是子节点
G[u][i] = -1;
continue;
}
dp(v);
d[u] += d[v];
}
//printf("d[%d]=%d\n", u, d[u]);
return;
}
int main()
{
#ifdef LOCAL
freopen("dpdata.txt", "r", stdin);
#endif
int t, n, a, b;
scanf("%d", &t);
while(t--)
{
scanf("%d", &n);
for(int i = 0; i <= n; i++)
{
G[i].clear();
}
for(int i = 0; i < n - 1; i++)
{
scanf("%d%d", &a, &b);
G[a].push_back(b);
G[b].push_back(a);
}
memset(d, 0, sizeof(d));
memset(vis, 0, sizeof(vis));
dp(1);
int ans = INF, p;
for(int i = 1; i <= n; i++)
{
int cur = d[1] - d[i];
for(int j = 0; j < G[i].size(); j++)
{
int k = G[i][j];
if(k == -1) continue;
cur = max(cur, d[k]);
}
if(cur < ans)
{
ans = cur; p = i;
}
}
printf("%d %d\n", p, ans);
}
return 0;
}
题目:HDU 1520 http://acm.hdu.edu.cn/showproblem.php?pid=1520
题意:办一个派对,邀请员工的时候,如果邀请了一个上司,那他的直接下属不能邀请了,每个人都有一个参与派对的欢乐值,要求出可以得到的派对欢乐值最大是多少。
思路:每个人选和不选都会决定他的直接下属是否可选,所以对于每个人要分两种情况讨论然后求最大值,要使用二维数组,dp[u][0] 记录不选 u 能得到的最大值,那么他的直接下属也有选和不选两种情况,dp[u][1] 记录选 u 能得到的最大值,那么他的直接下属只有不选的情况。
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <math.h>
#include <stdlib.h>
#define INF 0x7fffffff
#define MOD 1000000007
#include <vector>
using namespace std;
typedef long long ll;
int con[6005], in[6005];
vector<int> G[6005];
int d[6005][5];//0²»È¡£¬1È¡
void dp(int u)
{
//if(d[u][0] || d[u][1]) return;
d[u][0] = 0; d[u][1] = con[u];
for(int i = 0; i < G[u].size(); i++)
{
int v = G[u][i];
dp(v);
d[u][0] += max(d[v][0], d[v][1]);
d[u][1] += d[v][0];
}
//printf("d[%d][0]=%d,d[%d][1]=%d\n", u, d[u][0], u, d[u][1]);
return;
}
int main()
{
#ifdef LOCAL
freopen("dpdata.txt", "r", stdin);
#endif
int n, l, k;
while(scanf("%d", &n) != EOF)
{
for(int i = 1; i <= n; i++)
{
scanf("%d", &con[i]);
G[i].clear();
}
memset(in, 0, sizeof(in));
while(scanf("%d%d", &l, &k))
//for(int i = 0; i < n - 1; i++)
{
//scanf("%d%d", &l, &k);
if(l == 0 && k == 0) break;
G[k].push_back(l);
in[l]++;
}
int v;
for(int i = 1; i <= n; i++)
{
if(in[i] == 0)
{
v = i; break;
}
}
memset(d, 0, sizeof(d));
dp(v);
printf("%d\n", max(d[v][0], d[v][1]));
}
return 0;
}
题意:在一个地图上,有N座城堡,每座城堡都有一定的宝物,在每次游戏中ACboy允许攻克M个城堡并获得里面的宝物。求ACboy能得到的最大宝物数。
思路:地图就是一棵树,根节点为ACboy的起点0。dp[u][m]记录以u为起点,最多到达m个点,能得到的最大宝物数。类似背包问题中到达的点数为重量,宝物数为价值,容量为总重量。
代码:发现这个代码有记录价值的数组v[ ](全局),也有表示子节点的v(for循环里面),懒得改了,不影响运行结果。后面几个代码也是有这个问题。
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <math.h>
#include <stdlib.h>
#define INF 0x7fffffff
#define MOD 1000000007
#include <vector>
using namespace std;
typedef long long ll;
int n, m, v[205], sum[205], dp[205][205];
vector<int> G[205];
void dfs(int u)
{
sum[u] = 1;
dp[u][1] = v[u];
for(int i = 0; i < G[u].size(); i++)
{
int v = G[u][i];
dfs(v);
sum[u] += sum[v]; //当前以u为根节点可到达的节点数
for(int j = sum[u]; j >= 2; j--)
{
for(int k = 1; k <= sum[v]; k++)
{ //确定攻克了u,v为u子节点,枚举在以v为根节点的树上攻克的节点数
if(j - k < 1) continue;
dp[u][j] = max(dp[u][j], dp[u][j - k] + dp[v][k]);
//printf("(v%d) dp[%d][%d]=%d\n", v, u, j, dp[u][j]);
}
}
}
}
int main()
{
#ifdef LOCAL
freopen("dpdata.txt", "r", stdin);
#endif
int f;
while(scanf("%d%d", &n, &m) != EOF && (n && m))
{
for(int i = 0; i <= n; i++)
{
G[i].clear();
}
v[0] = 0;
for(int i = 1; i <= n; i++)
{
scanf("%d%d", &f, &v[i]);
G[f].push_back(i);
}
memset(dp, 0, sizeof(dp));
dfs(0);
printf("%d\n", dp[0][m + 1]); //起点为一个到达的点但不需要攻克
}
return 0;
}
题目:HDU 1011 http://acm.hdu.edu.cn/showproblem.php?pid=1011
题意:一群士兵在地洞里打虫子,他们想抓住虫子头目,他们每到了一个洞要走到更深的洞的时候之前都要留下一些士兵在当前的洞打虫子。士兵的数量有限,他们知道每个洞能找到虫子头目的概率,问他们抓住虫子头目的最大概率。
思路:也是在一个树形的地图里,起点为根节点。和上面一题差不多,不同的是每个点的权值不同(上面一题每个点权值都为1)。dp[u][w] 表示以 u 为起点用 w 个士兵找到虫子头目的最大概率。类似背包问题中需要士兵数为重量,概率为价值,容量为总士兵数。
代码:
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <math.h>
#include <stdlib.h>
#define INF 0x7fffffff
#define MOD 1000000007
#include <vector>
using namespace std;
typedef long long ll;
vector<int> G[205];
int n, m;
int w[105], v[105], vis[105], dp[105][105];
void dfs(int u)
{
vis[u] = 1;
for(int i = w[u]; i <= m; i++)
dp[u][i] = v[u];
//printf("d[%d][%d]=%d\n", u, m - w[u], dp[u][m - w[u]]);
for(int i = 0; i < G[u].size(); i++)
{
int v = G[u][i];
if(vis[v]) continue;
dfs(v);
for(int j = m; j >= w[u]; j--)
{
for(int k = 1; j - k >= w[u]; k++)
{
dp[u][j] = max(dp[u][j], dp[u][j - k] + dp[v][k]);
//printf("(v%d) d[%d][%d]=%d %d\n", v, u, j, dp[u][j], k);
}
}
}
}
int main()
{
#ifdef LOCAL
freopen("dpdata.txt", "r", stdin);
#endif
int a, b;
while(scanf("%d%d", &n, &m) != EOF)
{
if(n == -1 && m == -1) break;
for(int i = 1; i <= n; i++)
{
scanf("%d%d", &w[i], &v[i]);
w[i] = (w[i] + 20 - 1) / 20;
G[i].clear();
}
for(int i = 0; i < n - 1; i++)
{
scanf("%d%d", &a, &b);
G[a].push_back(b);
G[b].push_back(a);
}
memset(vis, 0, sizeof(vis));
memset(dp, 0, sizeof(dp));
if(m == 0)
{
printf("0\n");
continue;
}
dfs(1);
printf("%d\n", dp[1][m]);
}
return 0;
}
题目:POJ 1155 http://poj.org/problem?id=1155
题意:电视台要见一个电视网,电视网友发射站和用户接收端组成,终端为一个发射站,一个发射站可以连接其他发射站或用户,用户只能连发射站。连接每条线都有一定的费用,每个用户付的费用为一定值,求电视台在不亏本的情况下最多能连接多少个用户。
思路:前面两题的费用都为点的权值,这题的费用为边权,价值只在“用户”这些点上。dp[ u ] [ j ] 记录以u为起点连接 j 个用户的最大收益(初始化为负值),类似0-1背包中用价值作为容量的情况。
代码:
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <math.h>
#include <stdlib.h>
#define INF 0x7fffffff
#define MOD 1000000007
#include <vector>
using namespace std;
typedef long long ll;
int n, m, vis[3005], sum[3005], dp[3005][3005];
struct Edge
{
int to, cost;
}cur;
vector<Edge> G[3005];
void dfs(int u)
{
vis[u] = 1;
if(u <= n - m) sum[u] = 0;
else sum[u] = 1;
for(int i = 0; i < G[u].size(); i++)
{
Edge &e = G[u][i];
if(vis[e.to]) continue;
dfs(e.to);
//printf("v=%d\n", e.to);
sum[u] += sum[e.to];
dp[u][0] = 0;
for(int j = sum[u]; j >= 1; j--)
{
for(int k = 1; k <= sum[e.to]; k++)
{
if(j - k < 0) break;
dp[u][j] = (int)max((ll)dp[u][j], (ll)dp[u][j - k] + dp[e.to][k] - e.cost);
//printf("dp[%d][%d]=%d\n", u, j, dp[u][j]);
}
}
}
}
int main()
{
#ifdef LOCAL
freopen("dpdata.txt", "r", stdin);
#endif
int k, a, c;
while(scanf("%d%d", &n, &m) != EOF)
{
for(int i = 0; i <= n; i++)
{
G[i].clear();
}
for(int i = 1; i <= n - m; i++)
{
scanf("%d", &k);
for(int j = 0; j < k; j++)
{
scanf("%d%d", &a, &c);
cur.to = a; cur.cost = c;
G[i].push_back(cur);
}
}
memset(sum, 0, sizeof(sum));
memset(vis, 0, sizeof(vis));
for(int i = 0; i <= n; i++)
{
for(int j = 0; j <= m; j++)
{
dp[i][j] = -INF;
}
}
for(int i = n - m + 1; i <= n; i++)
{
scanf("%d", &dp[i][1]);
}
dfs(1);
for(int i = m; i >= 0; i--)
{
if(dp[1][i] >= 0)
{
printf("%d\n", i);
break;
}
}
}
return 0;
}
题目:ZOJ 3626 http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3626
题意:有一个危险的村庄,这里有时会有怪兽出现,已知怪兽多久会出现一次,怪兽出现之后发现谁不在家,就会把谁杀掉。一个勇士要得到了一张藏宝图并要去找宝藏,他知道村庄之间的距离和每个村庄里有多少宝藏,问他不被怪兽杀掉的话,最多能拿到多少宝藏。给出的n个村庄和n-1条路,村庄两两之间可以到达。
思路:地图是树形的,勇士所在的村庄为根节点,边权为重量,价值在点上,怪兽出现的时间的为容量。dp[u][j] 记录从 u 出发回到 u 用 j 天最多能收集到的宝藏数。
代码:
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <math.h>
#include <stdlib.h>
#define INF 0x7fffffff
#define MOD 1000000007
#include <vector>
using namespace std;
typedef long long ll;
int n, k, m, v[105], vis[105], dp[105][205];
struct Edge
{
int to, cost;
}cur;
vector<Edge> G[105];
void dfs(int u)
{
vis[u] = 1;
for(int i = 0; i <= m; i++)
dp[u][i] = v[u];
//dp[u][0] = v[u];
for(int i = 0; i < G[u].size(); i++)
{
Edge &e = G[u][i];
if(vis[e.to]) continue;
dfs(e.to);
//printf("v=%d\n", e.to);
for(int j = m; j >= 1; j--)
{
for(int k = 0; j - k - 2 * e.cost >= 0; k++)
{
dp[u][j] = max(dp[u][j], dp[u][j - k - 2 * e.cost] + dp[e.to][k]);
//printf("dp[%d][%d]=%d\n", u, j, dp[u][j]);
}
}
}
}
int main()
{
#ifdef LOCAL
freopen("dpdata.txt", "r", stdin);
#endif
int a, b, c;
while(scanf("%d", &n) != EOF)
{
for(int i = 1; i <= n; i++)
{
scanf("%d", &v[i]);
G[i].clear();
}
for(int i = 0; i < n - 1; i++)
{
scanf("%d%d%d", &a, &b, &c);
cur.cost = c; cur.to = a;
G[b].push_back(cur);
cur.to = b;
G[a].push_back(cur);
}
scanf("%d%d", &k, &m);
//m /= 2;
memset(vis, 0, sizeof(vis));
memset(dp, 0, sizeof(dp));
dfs(k);
printf("%d\n", dp[k][m]);
}
return 0;
}