Address
https://www.lydsy.com/JudgeOnline/problem.php?id=4565
Solution
区间合并让人想到区间 dp ,而
k≤8
k
≤
8
又让人想到状压 dp 。
我们考虑合二为一。
f[l][r][S]
f
[
l
]
[
r
]
[
S
]
表示将区间
[l,r]
[
l
,
r
]
内的字符不断合并,最后变成串
S
S
的最大收益。
( 是一个长度为
(r−l)mod(k−1)+1
(
r
−
l
)
mod
(
k
−
1
)
+
1
的
01
01
串)
(由于每次合并会减少
k−1
k
−
1
个字符,故
S
S
的长度固定)
考虑 的每个字符,它们都是由原串的一个区间逐渐压缩成的。
故
S
S
的每个字符互相独立,互不影响。
我们就枚举一个 ,表示
S
S
的最后一个字符是由原串的区间 压缩成的。
这时候就有一个非常传统的区间 dp 转移了!
以下把
mg(S,x)
m
g
(
S
,
x
)
定义为
(S<<1)|x
(
S
<<
1
)
|
x
,即在
S
S
的后面插入 。
x∈{0,1}
x
∈
{
0
,
1
}
。
其中 x∈{0,1} x ∈ { 0 , 1 } 。
注意上面针对的是 |S|=(r−l)mod(k−1)+1<k−1 | S | = ( r − l ) mod ( k − 1 ) + 1 < k − 1 的情况。
如果 |S|=k−1 | S | = k − 1 ,那么 [l,mid] [ l , m i d ] 会和 (mid,r] ( m i d , r ] 组成一个长度为 k k 的串,还可以再次合并。
故当 时:
同样 x∈{0,1} x ∈ { 0 , 1 } 。
理论复杂度 O(2kn3) O ( 2 k n 3 ) ,但实际状态没有那么多。
Code
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define For(i, a, b) for (i = a; i <= b; i++)
#define Rof(i, a, b) for (i = a; i >= b; i--)
#define Step(i, a, b, x) for (i = a; i <= b; i += x)
using namespace std;
inline int read() {
int res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
return bo ? ~res + 1 : res;
}
typedef long long ll;
const int N = 305, M = (1 << 8) + 5;
int n, k, a[N], c[M], w[M];
ll f[N][N][M >> 1];
int main() {
int i, j, mid, S;
n = read(); k = read();
For (i, 1, n) a[i] = read();
For (i, 0, (1 << k) - 1) c[i] = read(), w[i] = read();
memset(f, -1, sizeof(f));
For (i, 1, n) f[i][i][a[i]] = 0;
Rof (i, n, 1) For (j, i + 1, n) {
int orz = (j - i) % (k - 1) + 1;
if (orz > 1)
Step (mid, i + orz - 2, j - 1, k - 1)
For (S, 0, (1 << orz - 1) - 1) {
if (f[i][mid][S] == -1) continue;
if (f[mid + 1][j][0] != -1)
f[i][j][S << 1] = max(f[i][j][S << 1],
f[i][mid][S] + f[mid + 1][j][0]);
if (f[mid + 1][j][1] != -1)
f[i][j][(S << 1) | 1] = max(f[i][j][(S << 1) | 1],
f[i][mid][S] + f[mid + 1][j][1]);
}
else For (S, 0, (1 << k - 1) - 1) {
int tr0 = c[S << 1], tr1 = c[(S << 1) | 1];
Step (mid, i + k - 2, j - 1, k - 1) {
if (f[i][mid][S] == -1) continue;
if (f[mid + 1][j][0] != -1)
f[i][j][tr0] = max(f[i][j][tr0], f[i][mid][S]
+ f[mid + 1][j][0] + w[S << 1]);
if (f[mid + 1][j][1] != -1)
f[i][j][tr1] = max(f[i][j][tr1], f[i][mid][S]
+ f[mid + 1][j][1] + w[(S << 1) | 1]);
}
}
}
ll ans = -1;
For (i, 0, (1 << k) - 1) ans = max(ans, f[1][n][i]);
cout << ans << endl;
return 0;
}