题目链接
P1273 有线电视网
分析
有一定难度的树形
d
p
dp
dp,本质上是树上分组背包问题。对于树上任意一个非叶子节点的点
x
x
x,考虑取其子树中
j
j
j个叶子节点所能获得的最大利润,而这
j
j
j个叶子节点可以从
x
x
x的儿子节点取得。因此令
d
p
[
x
]
[
i
]
[
j
]
dp[x][i][j]
dp[x][i][j]表示从节点
x
x
x的前
i
i
i个儿子节点取
j
j
j个用户所能获得的最大利润,那么答案为使
d
p
[
1
]
[
s
o
n
[
1
]
]
[
j
]
≥
0
dp[1][son[1]][j] \ge 0
dp[1][son[1]][j]≥0的最大的
j
j
j值,其中
s
o
n
[
1
]
son[1]
son[1]为节点
1
1
1的儿子数,有递推方程
d
p
[
x
]
[
i
]
[
j
]
=
max
(
d
p
[
x
]
[
i
−
1
]
[
j
]
,
d
p
[
x
]
[
i
−
1
]
[
j
−
k
]
+
d
p
[
v
]
[
s
o
n
[
v
]
]
[
k
]
−
w
)
dp[x][i][j] = \max (dp[x][i - 1][j],dp[x][i - 1][j - k] + dp[v][son[v]][k] - w)
dp[x][i][j]=max(dp[x][i−1][j],dp[x][i−1][j−k]+dp[v][son[v]][k]−w),其中
v
v
v为
x
x
x的儿子节点,
w
w
w为
x
x
x、
v
v
v之间的花费。
上面
d
p
dp
dp数组会
M
L
E
MLE
MLE,考虑优化掉
i
i
i这一维,得到转移方程
d
p
[
x
]
[
j
]
=
max
(
d
p
[
x
]
[
j
]
,
d
p
[
x
]
[
j
−
k
]
+
d
p
[
v
]
[
k
]
−
w
)
dp[x][j] = \max (dp[x][j],dp[x][j - k] + dp[v][k] - w)
dp[x][j]=max(dp[x][j],dp[x][j−k]+dp[v][k]−w),注意此时应该倒着枚举
j
j
j(可以思考一下为什么)。初始化时令
{
d
p
[
i
]
[
0
]
=
0
,
i
∈
[
1
,
n
−
m
]
d
p
[
i
]
[
1
]
=
p
,
i
∈
[
n
−
m
+
1
,
n
]
d
p
[
i
]
[
j
]
=
−
inf
,
o
t
h
e
r
s
\left\{ \begin{array}{l} dp[i][0] = 0,\;\;\;i \in \left[ {1,n - m} \right]\\dp[i][1] = p,\;\;\;i \in \left[ {n - m + 1,n} \right]\\dp[i][j] = - \inf ,\;\;\;others\end{array} \right.
⎩⎨⎧dp[i][0]=0,i∈[1,n−m]dp[i][1]=p,i∈[n−m+1,n]dp[i][j]=−inf,others,那么最终的答案为使
d
p
[
1
]
[
[
j
]
≥
0
dp[1][[j] \ge 0
dp[1][[j]≥0的最大
j
j
j值。
代码
#include<bits/stdc++.h>
#define ll long long
#define FULL(x,y) memset(x,y,sizeof(x))
#define pb push_back
using namespace std;
const int N=3005,inf=-1e9;
int n,m,cnt;
int head[N],p[N],dp[N][N],sz[N];
struct edge {
int v,w,next;
}e[N];
void add(int u, int v, int w) {
e[++cnt].v=v;
e[cnt].w=w;
e[cnt].next=head[u];
head[u]=cnt;
}
int dfs(int x) {
int fl=0;
for(int i=head[x];i;i=e[i].next) {
int v=e[i].v;
fl=1;
sz[x]+=dfs(v);
for(int j=m;j>=1;j--) {
for(int k=1;k<=sz[v];k++) {
if (j-k>=0) dp[x][j]=max(dp[x][j],dp[x][j-k]+dp[v][k]-e[i].w);
}
}
}
if (!fl) return sz[x]=1;
return sz[x];
}
int main() {
cin>>n>>m;
int a,c,k;
for(int i=1;i<=n-m;i++) {
cin>>k;
for(int j=1;j<=k;j++) {
cin>>a>>c;
add(i,a,c);
}
}
for(int i=1;i<=m;i++) cin>>p[i];
for(int i=1;i<=n;i++) {
for(int j=0;j<=m;j++) dp[i][j]=inf;
}
for(int i=1;i<=n-m;i++) dp[i][0]=0;
for(int i=n-m+1;i<=n;i++) dp[i][1]=p[i-n+m];
int ans=0;
dfs(1);
for(int i=0;i<=m;i++) {
if (dp[1][i]>=0) ans=i;
}
cout<<ans;
return 0;
}