题意
给你一个文本串 S ( l e n ( S ) ≤ 1 0 5 ) S(len(S) \leq 10^5) S(len(S)≤105),由前 m ( m ≤ 20 ) m(m\leq20) m(m≤20)个小写字母组成。要你求一种键盘的排列,使得打出这个文本串的消耗最小。这个消耗 c o s t cost cost的计算方式为相邻字符的键盘距离之和,也就是 ∑ i = 2 n ∣ p o s s i − 1 − p o s s i ∣ \sum_{i=2}^{n} |pos_{s_i-1}-pos_{s_i}| ∑i=2n∣possi−1−possi∣。只需要输出这个最小消耗就行。
解题思路
打比赛时候以为是一道贪心题,结果怎么做都不对,后来看到
(
m
≤
20
)
(m\leq20)
(m≤20)就在想状压DP,但是完全不会做这种对于排列的DP。因为无法确定这个字符应该插入到什么位置,也就没法计算
c
o
s
t
cost
cost。
赛后看题解才知道原来DP还可以有预先计算
c
o
s
t
cost
cost这种操作,那么公式也就很好想了。
我们先预处理文本串
S
S
S中相邻字符对出现的次数。设
f
(
s
)
f(s)
f(s)为当前使用了集合为
s
s
s的字符进行排列,所能产生的最小
c
o
s
t
cost
cost。当我们加入一个新字符的时候,我们除了计算当前最小值,还要把还没有使用过的字符与使用过的字符所产生的
c
o
s
t
cost
cost也加进去,由此可得:
f
(
s
)
=
min
{
f
(
s
−
{
j
}
)
}
[
j
∈
s
]
f
(
s
)
=
f
(
s
)
+
c
o
s
t
(
i
,
j
)
[
i
∈
s
,
j
∈
s
‾
]
\begin{aligned} f(s) &=\min\{f(s-\{j\})\} & [j\in s]\\ f(s) &= f(s) + cost(i, j) & [i \in s,j\in \overline{s}]\\ \end{aligned}
f(s)f(s)=min{f(s−{j})}=f(s)+cost(i,j)[j∈s][i∈s,j∈s]
那么整体最小值就是
f
(
全
集
)
f(全集)
f(全集)了。
顺便吐槽一下状压
O
(
m
2
2
m
)
O(m^22^m)
O(m22m)都
4
×
1
0
8
4\times10^8
4×108了,结果
C
F
CF
CF神机居然半秒就跑出来了也是可以,我还以为会超时。。
时间复杂度
O ( m 2 2 m + S ) O(m^22^m+S) O(m22m+S)
代码
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <list>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <vector>
using namespace std;
typedef long long ll;
const int INF = 2147483647;
const int INF2 = 0x3f3f3f3f;
const ll INF64 = 1e18;
const double INFD = 1e30;
const double EPS = 1e-6;
const double PI = 3.1415926;
const ll MOD = 1e9 + 7;
int n, m, k;
int CASE;
const int MAXN = 300005;
char text[MAXN];
// 记录点对的数量
int cnt[25][25];
ll dp[(1 << 20) + 5];
int main() {
#ifdef LOCALLL
freopen("in", "r", stdin);
freopen("out", "w", stdout);
#endif
scanf("%d%d", &n, &k);
scanf("%s", text);
map<pair<char, char>, int> pairs;
for (int i = 0; i < n - 1; i++) {
char a = text[i];
char b = text[i + 1];
if (a == b) continue;
cnt[a - 'a'][b - 'a']++;
cnt[b - 'a'][a - 'a']++;
}
memset(dp, 0x3f, sizeof(dp));
dp[0] = 0;
for (int s = 1; s < (1 << k); s++) {
for (int i = 0; i < k; i++) {
if ((s >> i) & 1) {
dp[s] = min(dp[s], dp[s ^ (1 << i)]);
}
}
for (int i = 0; i < k; i++) {
if (!((s >> i) & 1)) continue;
for (int j = 0; j < k; j++) {
// 未访问的点的贡献
if ((~s >> j) & 1) {
dp[s] += cnt[i][j];
}
}
}
}
printf("%lld\n", dp[(1 << k) - 1]);
return 0;
}