建议访问原文出处,获得更佳浏览体验。
原文出处:https://hyp1231.github.io/2018/09/02/20180902-2018nanjing-online-e/
题意
有 n n 个问题( ),解决一个问题花费 1 1 分钟,但解决问题 时,必须已经解决问题 pi,1,pi,2,…,pi,si p i , 1 , p i , 2 , … , p i , s i (记为问题 i i 的前驱问题)。当问题 是你解决的第 t t 个问题(在第 分钟)时,总分加上 ai⋅t+bi a i ⋅ t + b i 。求总分的最大值。
注意不一定要解决所有问题,且一个问题的前驱问题可能是它自身。
链接
题解
观察到问题的个数只有 20 20 ,我们考虑使用状态压缩,把某个问题的是否解决,映射到一个二进制数的某一位是否为 1 1 。比如如果已经解决了问题 ,我们可以用二进制数 00001101 00001101 表示。
题目要求得分的最大值,考虑使用状压 DP 解决。令
dp[i]
d
p
[
i
]
表示状态为
i
i
时的得分最大值,则状态转移方程为:
这里需要满足:
- next_state n e x t _ s t a t e 比 state s t a t e 仅仅多解决了一个问题 i i 。
- 问题 不被包含在状态 state s t a t e 中。
- 问题 i i 的前驱问题都已经被解决(即都包含在状态 中)。
- state s t a t e 是可以到达的合法状态。
为了解决条件
3
3
,我们可以预处理,将问题 的前驱问题压缩为一个状态
neighbori
n
e
i
g
h
b
o
r
i
,这样我们只需判断
neighbori
n
e
i
g
h
b
o
r
i
是否是
state
s
t
a
t
e
的子集(例
10100
10100
是
10111
10111
的子集,但
11000
11000
不是
10111
10111
的子集)即可。这个可以用位运算方便地计算,建议先思考一下。具体见代码中的 bool sub_set(int f, int s)
函数。
而状态转移方程中的 t t 其实就是 的二进制表示中的 1 1 的个数。我们可以提前计算出 中 1 1 的个数,再加 即为 t t 。
这样算法主过程即为:初始化 为 −1 − 1 表示非法状态,边界条件 dp[0]=0 d p [ 0 ] = 0 。之后外层从小到大枚举状态,当状态合法时,内层枚举问题,如果满足上述 4 4 个条件则进行状态转移。时间复杂度 。
代码
#include <cstdio>
#include <algorithm>
#include <cstring>
typedef long long LL;
const int N = 21;
inline int count_one(int x) {
int sum = 0;
while (x) {
sum += x & 1;
x >>= 1;
}
return sum;
} // 计算 x 的二进制表示中有几个 1
inline bool sub_set(int f, int s) {
return (f | s) == f;
} // 如果 s 是 f 的子集(状态压缩),返回 true
int n, neighbor[N]; // neighbor[i] 表示第 i 个问题的前驱问题的状态压缩
LL a[N], b[N];
LL dp[1 << N], best;
int main() {
scanf("%d", &n);
int s, p;
for (int i = 1; i <= n; ++i) {
scanf("%lld%lld%d", &a[i], &b[i], &s);
for (int j = 0; j < s; ++j) {
scanf("%d", &p);
neighbor[i] |= 1 << (p - 1);
}
}
memset(dp, -1, sizeof(dp));
dp[0] = 0;
for (int st = 0; st < (1 << n); ++st) if (dp[st] != -1) { // 如果不是非法状态
int t = count_one(st) + 1; // 解决下一个问题的时间
for (int i = 1; i <= n; ++i)
if (((st >> (i - 1)) & 1) == 0) { // i 是还没有解决的问题
if (!sub_set(st, neighbor[i])) continue;
// 如果还没解决 i 的所有前驱问题,跳过
int next_st = st | (1 << (i - 1));
dp[next_st] = std::max(dp[next_st], dp[st] + t * a[i] + b[i]);
}
}
for (int st = 1; st < (1 << n); ++st) best = std::max(best, dp[st]);
printf("%lld\n", best);
return 0;
}