题目来自于 codewars 4kyu 题
题目:
使用 python 实现 Burrows-Wheeler 变换。
Burrows-Wheeler 变换是一种压缩算法,也被称为块排序压缩,压缩思路是对一个字符串进行处理,使其不改变字符,只改变字符的顺序,尽可能的将字符串中重复的多个子串放在一起,便于使用基于处理字符串中连续重复字符的技术(如MTF变换和游程编码)的编码更容易被压缩。(参考自百度百科)
对于一个字符串 "bananabar",先将其进行全循环排列:
Input: "bananabar"
b a n a n a b a r
r b a n a n a b a
a r b a n a n a b
b a r b a n a n a
a b a r b a n a n
n a b a r b a n a
a n a b a r b a n
n a n a b a r b a
a n a n a b a r b
然后对全循环排列后的矩阵进行按照字典排序,将排序后的矩阵的最后一列和原字符串所在行的索引作为输出。
.-.
a b a r b a n a n
a n a b a r b a n
a n a n a b a r b
a r b a n a n a b
b a n a n a b a r <- 4
b a r b a n a n a
n a b a r b a n a
n a n a b a r b a
r b a n a n a b a
'-'
Output: ("nnbbraaaa", 4)
题目的目标是写一个 encode 函数和一个 decode 函数,实现Burrows-Wheeler变换。(注意:存在空的字符串输入,其行号将被忽略)
个人思路:
encode 函数:
先通过题目中所提示的方式,构建全循环排列矩阵。
通过使用切片的方式,将字符串从一个位置切开,将左右两边颠倒后连接。
[s[-i:]+s[:-i] for i in range(len(s))]
对于生成的全循环排列矩阵进行字典排序
l = sorted(s[-i:]+s[:-i] for i in range(len(s)))
根据排序后的全循环排列矩阵,进行函数的输出
输出为一个元组,第一个元素是排序后的全循环排列矩阵的最后一列构成的字符串,通过索引和循环完成。
"".join(i[-1] for i in l)
第二个元素是原本的字符串在排序后的全循环排列矩阵中的行索引。
l.index(s)
最后加入对于空字符串输入的特殊处理,encode 函数如下
def encode(s):
if s=='':
return '', 0
l = sorted(s[-i:]+s[:-i] for i in range(len(s)))
return "".join(i[-1] for i in l), l.index(s)
decode 函数:
Burrows-Wheeler变换的逆变换过程如下:
将编码串排在矩阵的第一列,然后进行字典排序。将排序后的矩阵作为新的矩阵。将编码串再次排在矩阵的第一列,然后再进行字典排序,直到遍历完整个编码串。
Input: ("nnbbraaaa", 4)
n a
n a
b a
b a
r sort=> b
a b
a n
a n
a r
na ab
na an
ba an
ba ar
rb sort=> ba
ab ba
an na
an na
ar rb
nab aba
nan ana
ban ana
bar arb
rba sort=> ban
aba bar
ana nab
ana nan
arb rba
naba abar
nana anab
bana anan
barb arba
rban sort=> bana
abar barb
anab naba
anan nana
arba rban
.
.
.
.
.
.
.
.
.
.
a b a r b a n a n
a n a b a r b a n
a n a n a b a r b
a r b a n a n a b
b a n a n a b a r
b a r b a n a n a
n a b a r b a n a
n a n a b a r b a
r b a n a n a b a
不难看出我们获得了原本的排序后的全循环排列矩阵,然后取索引为 4 的行,就是我们想要的原字符串了。
构建空矩阵
l=['']*len(s)
循环,通过上述的逆转换方法,获得原本的排序后的全循环排列矩阵
将 l 的的第一列前面加上编码串作为新的第一列,原本的第一列右移
for _ in range(len(s)):
l=sorted(s[i]+l[i]for i in range(len(s)))
对空字符串输入进行特殊处理,输出结果
def decode(s, n):
if not s:return s
l=['']*len(s)
for _ in range(len(s)):
l=sorted(s[i]+l[i]for i in range(len(s)))
return l[n]