一、什么是01背包问题?
举个例子,你要去一个水果摊拿水果,每种水果都有对应的两种属性:占用的体积V和蕴含的价值W。而你的背包体积为N。
老板说:每种水果只能拿一个!因此对于咱们肯定得想一种搭配方式使得拿的水果总体积不超过背包容积,但是价值总和达到最大!
二、01背包例题
第一行两个整数,N,V,用空格隔开,分别表示物品数量和背包容积。
接下来有 N 行,每行两个整数 vi,wi,用空格隔开,分别表示第 i 件物品的体积和价值。
数据范围:
0<N,V≤1000
0<vi,wi≤1000
输入样例:
4 5
1 2
2 4
3 4
4 5
三、推导状态方程
- 设物品数N = 4,背包容量V = 5;因此声明一个 4 * 5 的数组, 生成以下状态表。
- 按照物品顺序更新状态表
存在状态数组DP[i][j](0<= i<=N, 0<=j <=V),表示存在i件物品时,背包容量j对应的最大价值;
1)物品1时,
当 j = 0 ,f[1][0] = 0;
当1 <= j <= V ,f[1][j] = 2;
2)物品2时,
当 j = 0,f[2][0] = 0;
当1 <= j < 2,无法放入物品2,使用前一物品1的状态,f[2][j] =f[1][j];
当2 <= j <= V,可以放入物品2,考虑放入和不放入的情况,
当V = 2时,可以选择放入物品2或者不放入物品2,f[2][2] = max(2, 4);
当V = 3时,可以选择放入物品2或者不放入物品2,f[2][2] = max(2, 6);
以上可以总结出j>=2时, f[2][j] = max(f[1][j],f[1][j - v[2]] + w[2]);
3) 物品3及以上时,
j < w[j]时,f[i][j] = f[i - 1][j]
j>=w[j]时,f[i][j] = max(f[i - 1][j],f[i - 1][j - v[i]] + w[i]);
四、代码实现
#include <malloc.h>
#include <stdio.h>
#include <string.h>
int main() {
int arr[1000][2];
int v, n, *status[1000];
while (scanf("%d %d", &n, &v) != EOF) {
for (int i = 1; i <= n; i++) { // 任意物品i都需要依赖前一物品i-1的状态,为减少后续DP过程的计算,物品和状态都从1开始
scanf("%d %d", &arr[i][0], &arr[i][1]);
status[i] = (int*)malloc(sizeof(int) * (v + 1));
}
status[0] = (int*)malloc(sizeof(int) * (v + 1));
memset(status[0], 0x00, sizeof(int) * (v + 1)); // 初始化物品0的状态
for (int i = 1; i <= n; i++) {
status[i][0] = 0;
for (int j = 1; j <= v; j++) {
status[i][j] = status[i - 1][j];
if (j >= arr[i][0]) {
status[i][j] =
status[i][j] > (status[i - 1][j - arr[i][0]] + arr[i][1])
? status[i][j]
: (status[i - 1][j - arr[i][0]] + arr[i][1]);
}
}
}
printf("%d\n", status[n][v]);
}
}
五、优化
从状态的推导图中,可以看出以下2点:
- 只使用了status[i]和status[i-1]这2层状态,实际空间复杂度 = n;
- 当j<v时,继承status[i-1][j];
- 当j>=v时, 只使用了上一层的status[i-1][j],status[i-1][j - v[i]];
针对以上3点,做出以下优化:
- 申请status[v],存储最新的状态
- 当j<v时,保留上一次的状态
- 当j>=v时, f[j] = max(f[j],f[j - v[i]] + w[i]);
- 由于f[j]会依赖f[j - v[i]]的状态,需要修改更新status[j]的顺序,从j=v开始,保证f[j - v[i]]在f[j]之后更新。
#include <stdio.h>
int main() {
int arr[1000][2];
int v, n, status[1000] = {0};
while (scanf("%d %d", &n, &v) != EOF) {
for (int i = 1; i <= n; i++) {
scanf("%d %d", &arr[i][0], &arr[i][1]);
}
for (int i = 1; i <= n; i++) {
for (int j = v; j >= arr[i][0]; j--) {
status[j] =
status[j] > (status[j - arr[i][0]] + arr[i][1])
? status[j]
: (status[j - arr[i][0]] + arr[i][1]);
}
}
printf("%d\n", status[v]);
}
}