P r o b l e m \mathrm{Problem} Problem
有一个长度为 n 的 01 串,你可以每次将相邻的 k 个字符合并,得到一个新的字符并获得一定分数。得到的新字符和分数由这 k 个字符确定。你需要求出你能获得的最大分数。
1<=n<=300,0<=ci<=1,wi>=1,k<=8
S o l u t i o n \mathrm{Solution} Solution
首先我们知道,对于一串数中它的最优解长度一定是小于等于k的。
由于观察到 k k k的数据范围非常小,我们可以考虑用状态来完成。但是有长得像区间DP,我们可以考虑设置状态 f [ i ] [ j ] [ k ] f[i][j][k] f[i][j][k]表示区间 [ i , j ] [i,j] [i,j]最后经过合并后形成状态为 k k k的最大代价。
我们可以考虑对于每一个 i i i和 j j j,枚举中间点 m m m,使得 m m m到 j j j占意味, i i i到 m − 1 m-1 m−1占 k − 1 k-1 k−1位。
这样就能很容易得到转移方程:
f
[
i
]
[
j
]
[
s
<
<
1
]
=
max
(
f
[
i
]
[
m
−
1
]
[
s
]
+
f
[
m
]
[
j
]
[
0
]
)
f
[
i
]
[
j
]
[
s
<
<
1
∣
1
]
=
max
(
f
[
i
]
[
m
−
1
]
[
s
]
+
f
[
m
]
[
j
]
[
1
]
)
f[i][j][s<<1] = \max(f[i][m-1][s]+f[m][j][0])\\ f[i][j][s<<1|1] = \max(f[i][m-1][s]+f[m][j][1])
f[i][j][s<<1]=max(f[i][m−1][s]+f[m][j][0])f[i][j][s<<1∣1]=max(f[i][m−1][s]+f[m][j][1])
还有一种情况就是,当
r
−
l
+
1
=
k
r-l+1=k
r−l+1=k时,我们需要直接进行合并。我们可以这样实现:
C o d e \mathrm{Code} Code
#include <cstdio>
#include <cstring>
#include <iostream>
#define int long long
using namespace std;
const int K = 8;
const int N = 305;
int n, k, INF;
int f[N][N][(1<<K)+9], a[N], c[(1<<K)+9], w[(1<<K)+9];
int read(void)
{
int s = 0, w = 0; char c = getchar();
while (c < '0' || c > '9') w |= c == '-', c = getchar();
while (c >= '0' && c <= '9') s = s*10+c-48, c = getchar();
return w ? -s : s;
}
void DP(void)
{
memset(f,-30,sizeof f);
INF = f[0][0][0];
for (int i=1;i<=n;++i) f[i][i][a[i]] = 0;
for (int l=2;l<=n;++l)
for (int i=1;i+l-1<=n;++i) {
int j = i + l - 1;
int len = j - i;
while (len >= k) len -= (k-1);
for (int m=j;m>i;m-=k-1)
for (int s=0;s<1<<len;++s) {
f[i][j][s<<1] = max(f[i][j][s<<1],f[i][m-1][s]+f[m][j][0]);
f[i][j][s<<1|1] = max(f[i][j][s<<1|1],f[i][m-1][s]+f[m][j][1]);
}
if (len == k-1)
{
int g[2]; g[0] = g[1] = INF;
for (int s=0;s<1<<k;++s)
g[c[s]] = max(g[c[s]],f[i][j][s]+w[s]);
f[i][j][0] = g[0], f[i][j][1] = g[1];
}
}
int ans = INF;
for (int i=0;i<1<<k;++i)
ans = max(ans,f[1][n][i]);
cout<<ans<<endl;
}
signed main(void)
{
n = read(), k = read();
for (int i=1;i<=n;++i)
scanf("%1d", a+i);
for (int i=0;i<1<<k;++i)
c[i] = read(), w[i] = read();
DP();
return 0;
}