01背包
问题描述
给定一个容积为V的背包,现在有n件物品,第i件物品的体积为w i,价值为vi,每件物品只能拿或者不拿,请求出体积总和不超过V的最大价值。
解题思路
状态
dp[i][j]表示前i件物品,体积为j时的最大价值。
状态转移
对于第i件物品,且第i件物品的体积比j大时,第i件物品一定不拿。
对于第i件物品,且第i件物品的体积比j小时,可能有拿or不拿两种状态。
- 拿:前i件物品体积为j由前i-1件物品体积减掉第i件物品的体积(由于前i件物品的体积加上第i件物品的体积才是j)的状态转移而来
-
dp[i-1][j-w[i]]=>dp[i][j]
- 不拿:前i件物品体积为j由前i-1件物品体积也为j转移而来(因为第i件物品不拿,所以体积不变)
-
dp[i-1][j]=>dp[i][j]
通过在限制体积下比较拿以及不拿哪种情况价值更大来不断推进。
边界条件
当件数i或者体积j任一为0时,最大价值一定为0。
即dp[0][0]=dp[…][0]=dp[0][…]=0
动态规划转移方程
d p [ i ] [ j ] = { d p [ i − 1 ] [ j ] , j < w [ i ] m a x ( d p [ i − 1 ] [ j ] , d p [ i − 1 ] [ j − w [ i ] ] + v [ i ] ) , j > = w [ i ] dp[i][j]=\begin{cases} dp[i-1][j],j<w[i]\\ max(dp[i-1][j],dp[i-1][j-w[i]]+v[i]),j>=w[i] \end{cases} dp[i][j]={dp[i−1][j],j<w[i]max(dp[i−1][j],dp[i−1][j−w[i]]+v[i]),j>=w[i]
代码实现
n,v=map(int,input().split())
a=[[0]*(n+1)]
for i in range(n):
a.append([0]+list(map(int,input().split())))
dp=[[0]*(v+1) for _ in range(n+1)]
for i in range(1,n+1):
for j in range(0,v+1):
if a[i][1]<=j:
dp[i][j]=max(dp[i-1][j],dp[i-1][j-a[i][1]]+a[i][2])
else:
dp[i][j]=dp[i-1][j]
print(dp[n][v])
对于前面的数据输入部分,也可以采用:
n,v=map(int,input().split())
for i in range(1,n+1):
wi,vi=map(int,input().split())
for j in range(0,v+1):
if wi<=j:
dp[i][j]=max(dp[i-1][j],dp[i-1][j-wi]+vi)
else:
dp[i][j]=dp[i-1][j]
print(dp[n][V])
滚动数组优化
在上述解法中,我们使用的dp数组占用空间为(v+1)(n+1),为了节省空间,可以使用滚动数组优化。
由于前i个物品都是由前i-1个物品转化而来,两者的奇偶性一定不同,可以使用长度为2的滚动数组优化空间。
长度为2的滚动数组
dp[(i%2)-1][j-w[i]] | … | dp[(i%2)-1][j] |
---|---|---|
… | dp[i%2][j] |
把数组按照奇偶性不同分成两行,每一次都在上一组数组的基础上运算。
代码实现
n,V=map(int,input().split())
dp=[[0]*(V+1) for _ in range(2)]
for i in range(1,n+1):
w,v=map(int,input().split())
for j in range(V+1):
if j<w:
dp[i%2][j]=dp[(i-1)%2][j]
else:
dp[i%2][j]=max(dp[(i-1)%2][j],dp[(i-1)%2][j-w]+v)
print(dp[n%2][V])
长度为1的滚动数组
dp[i-1][j-w[i]] | … | dp[i-1][j] |
---|---|---|
… | dp[i][j] |
如果把这里第一行数组计算出来的结果不是放到另一行,而是直接覆盖在原来的数组,就是长度为1的滚动数组。
解题思路
更新dp[i][j]时,用到上一行对应位置dp[i-1][j]和上一行先前位置dp[i-1][j-w[i]]的元素,因此可以使用单个数组进行更新,直接从大到小对dp数组进行覆盖即可。
代码实现
n,v=map(int,input().split())
dp=[0]*(v+1)
for i in range(1,n+1):
w,v=map(int,input().split())
for j in range(v,w-1,-1):
dp[j]=max(dp[j],dp[j-w]+v)
print(dp[v])