树上背包问题
一些题目给定了树形结构,在这个树形结构中选取一定数量的点或边(也可能是其他属性),使得某种与点权或者边权相关的花费最大或者最小。解决这类问题,一般要考虑使用树上背包。
算法原理
树上背包,顾名思义,就是在树上做背包问题。一个节点的若干子树可以看作是若干组背包,也就是用树形dp的方式做分组背包问题。一般来说, f ( i , j ) f(i,j) f(i,j)表示以 i i i为根的子树中,在 j j j的容量范围内,最大或者最小可以获得多少收益。根据分组背包的思想,第一维枚举物品(在树上指的是子树),第二维枚举容量,第三维枚举决策(这里指的是给子树分配多少容量)。基本的代码框架如下:
void dfs(int u, int fa)
{
for(int i = h[u]; ~i; i = ne[i])
{
int son = e[i];
if(son == fa) continue;
dfs(son, u);
for(int j = m; j >= 0; j --)
for(int k = 0; k <= j; k ++)
f[u][j] = max(f[u][j], f[u][j-k] + f[son][k] + val);
}
}
例题一:有依赖的背包问题
题意
有
n
n
n个物品和一个容量是
m
m
m的背包。物品之间具有依赖关系,且依赖关系组成一棵树的形状。如果选择一个物品,则必须选择它的父节点。
求解将哪些物品装入背包,可使物品总体积不超过背包容量,且总价值最大。输出最大价值。
每件物品的编号是
i
i
i,体积是
v
i
v_i
vi,价值是
w
i
w_i
wi,依赖的父节点编号是
p
i
p_i
pi。物品的下标范围是
1
…
N
1 \dots N
1…N。
数据范围
1
≤
n
,
m
≤
100
1 \leq n,m \leq 100
1≤n,m≤100
1
≤
v
i
,
w
i
≤
100
1 \leq v_i,w_i \leq 100
1≤vi,wi≤100
思路
f
(
i
,
j
)
f(i,j)
f(i,j)表示选择以
i
i
i为子树的物品,在容量不超过
j
j
j时所获得的最大价值。
由于只有选择了根节点,才会继续往下遍历,所以在遍历到
i
i
i节点时,先考虑一定选上它。
在分组背包部分,
j
j
j的范围为
[
m
,
v
[
i
]
]
[m,v[i]]
[m,v[i]],否则没有意义,因为连根节点也放不下;
k
k
k的范围
[
0
,
j
−
v
[
i
]
]
[0,j-v[i]]
[0,j−v[i]],当大于
j
−
v
[
i
]
j-v[i]
j−v[i]时分给该子树的容量过多,剩余的容量连根节点的物品都放不下了。
递推式为:
f
(
i
,
j
)
=
m
a
x
(
f
(
i
,
j
)
,
f
(
i
,
j
−
k
)
+
f
(
s
o
n
,
k
)
)
f(i,j) = max(f(i,j), f(i,j - k) + f(son,k))
f(i,j)=max(f(i,j),f(i,j−k)+f(son,k))。
代码
void dfs(int u)
{
for(int i = v[u]; i <= m; i ++) f[u][i] = w[u];
for(int i = h[u]; ~i; i = ne[i])
{
int son = e[i];
dfs(son);
for(int j = m; j >= v[u]; j --)
for(int k = 0; k <= j - v[u]; k ++)
f[u][j] = max(f[u][j], f[u][j - k] + f[son][k]);
}
}
例题二:二叉苹果树
题意
给定一棵二叉树,每条边有边权,保留一定数量的边(其他边删除),使得保留下来的边的边权和最大。
数据范围
1
≤
n
<
m
≤
100
1 \leq n < m \leq 100
1≤n<m≤100
w
i
≤
30000
w_i \leq 30000
wi≤30000
思路
f
(
i
,
j
)
f(i,j)
f(i,j)表示以
i
i
i为根的子树中,恰好保留
j
j
j条边的最大边权和。
若需要选择该子树中的边,则根结点到子树的边一定要选,因此能用上的总边数一定减
1
1
1,总共可以选择
j
j
j条边时,当前子树son分配的最大边数是
j
−
1
j - 1
j−1。
递推式为,
f
(
i
,
j
)
=
m
a
x
(
f
(
i
,
j
)
,
f
(
i
,
j
−
k
−
1
)
+
f
(
s
o
n
,
k
)
+
w
[
i
]
)
f(i,j) = max(f(i,j), f(i,j-k-1) + f(son, k) + w[i])
f(i,j)=max(f(i,j),f(i,j−k−1)+f(son,k)+w[i])。
代码
void dfs(int u, int fa)
{
for(int i = h[u]; ~i; i = ne[i])
{
int son = e[i];
if(son == fa) continue;
dfs(son, u);
for(int j = m; j >= 1; j -- )
for(int k = 0; k <= j - 1; k ++ )
f[u][j] = max(f[u][j], f[u][j - k - 1] + f[son][k] + w[i]);
}
}
例题三:Factories(2018icpc银川网络赛)
题意
给定一棵树,边有边权。每个叶子节点上最多可以布置一个工厂,总共要布置 k k k个工厂。问怎样布置工厂,使得工厂之间的距离和最小。
数据范围
10
s
10s
10s
2
≤
n
≤
1
0
5
2 \leq n \leq 10^5
2≤n≤105,
1
≤
m
≤
100
1 \leq m \leq 100
1≤m≤100
1
≤
w
i
≤
1
0
5
1 \leq w_i \leq 10^5
1≤wi≤105
多组测试数据,
n
n
n总数不超过
1
0
6
10^6
106
思路
直接考虑距离之和非常困难,所以可以考虑每条边被计算了几次(距离和等类似问题很多都是这么考虑的)。不妨设一条边为
i
i
i,与
i
i
i相连的子树中有
j
j
j个工厂,则这条边被计算的次数为
j
∗
(
m
−
j
)
j*(m - j)
j∗(m−j)。
f
(
i
,
j
)
f(i,j)
f(i,j)表示以
i
i
i为根节点的子树中,选择恰好
j
j
j个叶子节点的距离总和。
递推式为,
f
(
i
,
j
)
=
m
i
n
(
f
(
i
,
j
)
,
f
(
i
,
j
−
k
)
+
f
(
s
o
n
,
k
)
+
w
[
i
]
∗
j
∗
(
m
−
j
)
)
f(i,j) = min(f(i,j), f(i,j - k) + f(son, k) + w[i] * j * (m - j))
f(i,j)=min(f(i,j),f(i,j−k)+f(son,k)+w[i]∗j∗(m−j))。
因为只能分布在叶子节点,因此初始化的时候要注意,如果点
i
i
i为叶子节点,那么
f
(
i
,
1
)
=
0
f(i,1) = 0
f(i,1)=0。
同时这道题要卡常数,所以要对状态做一个优化,即把无效状态去掉。
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 100003, M = 103;
const ll inf = 1e18;
int n, m;
int h[N], e[2*N], ne[2*N], w[2*N], idx;
int s[N], deg[N];
ll f[N][M];
void add(int a,int b,int c)
{
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx ++;
}
void dfs(int u,int fa)
{
for(int i = h[u]; ~i; i = ne[i])
{
int son = e[i];
if(son == fa) continue;
dfs(son, u);
s[u] += s[son];
for(int j = min(m, s[u]); j >= 1; j --)
for(int k = 1; k <= min(j, s[son]); k ++)
f[u][j] = min(f[u][j], f[u][j-k] + f[son][k] + (ll)w[i] * k * (m - k));
}
}
int main()
{
int T;
scanf("%d", &T);
int cas = 0;
while(T --)
{
scanf("%d%d", &n,&m);
for(int i = 1; i <= n; i ++) h[i] = -1, deg[i] = 0;
idx = 0;
for(int i = 0; i < n - 1; i ++)
{
int a,b,c;
scanf("%d%d%d", &a,&b,&c);
add(a,b,c), add(b,a,c);
deg[a] ++, deg[b] ++;
}
for(int i = 1; i <= n; i ++)
{
s[i] = 0;
for(int j = 1; j <= m; j ++) f[i][j] = inf;
if(deg[i]==1) f[i][1] = 0, s[i] = 1;
}
dfs(1, -1);
printf("Case #%d: %lld\n",++cas,f[1][m]);
}
return 0;
}