题面
link
给定一个长为
n
(
n
≤
1
e
5
)
n (n ≤ 1e5)
n(n≤1e5) 的只含数字1 - 8的字符串,每出现一个逆序对
(
a
,
b
)
(a, b)
(a,b) (其中
b
<
a
b < a
b<a) 就会有
P
a
,
b
P_{a, b}
Pa,b 的 cost, 比如字符串
85511
85511
85511 的 cost 为
2
×
P
8
,
5
+
2
×
P
8
,
1
+
4
×
P
5
,
1
2 × P_{8, 5} + 2 × P_{8, 1} + 4 × P_{5, 1}
2×P8,5+2×P8,1+4×P5,1。
此外还有一个变换操作,可以花费
C
a
,
b
C_{a, b}
Ca,b 将所有的
a
a
a 换成
b
b
b,
b
b
b 换成
a
a
a,如花费
C
8
,
5
C_{8, 5}
C8,5(或者
C
5
,
8
C_{5, 8}
C5,8) 可以将字符串
85511
85511
85511 变成
58811
58811
58811。我们可以进行任意次的变换操作,最后要计算逆序对的花费,求总共最小的花费。
输入为长度
n
n
n,数字字符串,以及
8
×
8
8 × 8
8×8 的
P
P
P 矩阵和
C
C
C 矩阵,其中
P
P
P 矩阵是下三角矩阵(因为正序对的花费自然为0),
C
C
C 矩阵是对称矩阵。
分析
其实要求逆序对的花费,就是求各种数对的花费,因为正序对的花费都是 0。
由于变换操作是相同的全部数字一起变的,即一开始数字一样的位置,无论经过多少次变换,最后还是一样的;一开始不一样的数字,最后也肯定不一样的。那么字符串就可以最多分成 8 组, 第
i
i
i 组最开始是表示数字
i
i
i, 经过若干次变换之后,这一组可能会变成其他任何数字。
要计算的花费分为变换的花费和变换后数对的花费。
数对花费:若用
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j] 表示 第
i
i
i 组和第
j
j
j 组能组成的
(
i
,
j
)
(i, j)
(i,j)数对的个数,则没有变换的数对花费就是
∑
i
=
1
8
∑
j
=
1
8
d
p
[
i
]
[
j
]
×
P
[
i
]
[
j
]
\sum_{i = 1}^8\sum_{j = 1}^8 dp[i][j] ×P[i][j]
∑i=18∑j=18dp[i][j]×P[i][j], 若记第
i
i
i 组数最后变成
m
a
r
k
[
i
]
mark[i]
mark[i], 则花费就是
∑
i
=
1
8
∑
j
=
1
8
d
p
[
i
]
[
j
]
×
P
[
m
a
r
k
[
i
]
]
[
m
a
r
k
[
j
]
]
\sum_{i = 1}^8\sum_{j = 1}^8 dp[i][j] ×P[mark[i]][mark[j]]
∑i=18∑j=18dp[i][j]×P[mark[i]][mark[j]]。
可以通过线性时间计算
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j]:
for(int i = 0; s[i]; i++) //计算组对的个数
{
int tmp = s[i] - '0';
for(int j = 1; j <= 8; j++)
dp[j][tmp] += num[j];
num[tmp]++;
}
变换花费:一开始各个组对应的序列就是
12345678
12345678
12345678, 此时的变换花费为0, 而最多有
8
!
8!
8!这么多序列可以变换,那该怎么计算到每个序列的最少花费呢?
一开始我用的是递归的方式,即每个序列可以通过变换两组的数字变成另一个序列,但是这样的话每一个序列可以有
C
8
2
C_8^2
C82 种变换,那么
8
!
8!
8! 种序列的开销太大了。比如从
12345678
12345678
12345678,变成
87654321
87654321
87654321,有很多种方式,通过递归找到最短的变换方式并不是好的选择。
由此引入了建图的思想,将这个问题考虑成一张图,序列为点,其中相邻点(可以通过一次操作变换得到)之间的边即为变换的花费,那么我们要求的即是最初序列到其他序列的最小花费,即为单源最短路问题了。
如果用Dijkstra 方法计算,边数为
8
!
×
C
8
2
2
\frac{8!×C_8^2}{2}
28!×C82, 点数即为
8
!
8!
8!,而怎么将一个序列映射到点的标号呢?这里可以用生成排列的序号,如
12345678
12345678
12345678 是
0
0
0,
12345687
12345687
12345687是
1
1
1,
87654321
87654321
87654321是
8
!
−
1
8!-1
8!−1 :
int GetIndex(int* a) //获取一个排列的序号,从0开始
{
int res = 0;
for(int i = 1; i < 8; i++)
{
int cnt = 0;
for(int j = i + 1; j <= 8; j++)
if(a[i] > a[j])
cnt++;
res += cnt * fac[8 - i]; //fac[i]表示阶乘
}
return res;
}
综上,只需要预先计算序列之间的变换花费 ( O ( 8 ! × C 8 2 2 l g 8 ! ) ) O(\frac{8!×C_8^2}{2}lg8!)) O(28!×C82lg8!)), 预处理组对数( O ( 8 n ) O(8n) O(8n)), 对每个序列计算数对花费( O ( 8 ! × 64 ) O(8!× 64) O(8!×64))。
代码
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long ll;
const int maxn = 1e5 + 10;
const int INF = 0x3f3f3f3f;
const double eps = 1e-8;
int fac[10], n, dp[9][9], num[9];
ll P[9][9], C[9][9];
char s[maxn];
int GetIndex(int* a) //获取一个排列的序号,从0开始
{
int res = 0;
for(int i = 1; i < 8; i++)
{
int cnt = 0;
for(int j = i + 1; j <= 8; j++)
if(a[i] > a[j])
cnt++;
res += cnt * fac[8 - i];
}
return res;
}
struct node
{
int a[9];
ll w;
bool operator<(const node& m) const //使得优先队列可以小顶堆
{
return w > m.w;
}
};
ll d[41000];
priority_queue<node> que;
void Dijkstra()
{
memset(d, INF, sizeof(d));
d[0] = 0; //初始位置为序列12345678
node u;
u.w = 0;
for(int i = 1; i <= 8; i++)
u.a[i] = i;
que.push(u);
while(!que.empty())
{
node p = que.top(); que.pop();
int k1 = GetIndex(p.a);
if(d[k1] < p.w) continue; //去掉已经被访问过的节点和被更新过的边长
node tmp = p;
for(int i = 1; i < 8; i++) //遍历与这个序列相邻的序列,进行更新
for(int j = i + 1; j <= 8; j++)
{
swap(tmp.a[i], tmp.a[j]);
int k2 = GetIndex(tmp.a);
if(d[k2] > d[k1] + C[i][j])
{
d[k2] = d[k1] + C[i][j];
tmp.w = d[k2];
que.push(tmp);
}
swap(tmp.a[i], tmp.a[j]);
}
}
}
int main()
{
fac[1] = 1; //计算阶乘
for(int i = 2; i <= 8; i++)
fac[i] = fac[i-1] * i;
scanf("%d %s", &n, s);
for(int i = 1; i <= 8; i++)
for(int j = 1; j <= 8; j++)
scanf("%lld", &P[i][j]);
for(int i = 1; i <= 8; i++)
for(int j = 1; j <= 8; j++)
scanf("%lld", &C[i][j]);
Dijkstra(); //计算序列之间转换的花费
for(int i = 0; s[i]; i++) //计算组对的个数
{
int tmp = s[i] - '0';
for(int j = 1; j <= 8; j++)
dp[j][tmp] += num[j];
num[tmp]++;
}
int mark[9] = {0, 1, 2, 3, 4, 5, 6, 7, 8}; //对每一个序列计算答案
ll ans = LLONG_MAX;
do
{
ll res = d[GetIndex(mark)];
for(int i = 1; i <= 8; i++)
for(int j = 1; j <= 8; j++)
res += dp[i][j] * P[mark[i]][mark[j]];
ans = min(res, ans);
}while(next_permutation(mark + 1, mark + 8 + 1));
printf("%lld\n", ans);
}
虽然题目给的n 是 1e5, 不过要开到 1e6,奇奇怪怪
收获
① 建图的思想,长姿势了,瞬间将一个复杂的递归转成一个带log的线性做法。
②复习了离散数学中学的生成排列,以及C++ next_permutation的用法