前两天微博上有人讨论洗牌程序,没细看内容但感觉似乎有点意思,今天自己尝试一下。
所谓洗牌程序就是把一个序列的元素位置打乱,这在 Python 里有一个标准函数:random.shuffle
。在开始动手之前我们先简化描述一下需求:
洗牌后的每个元素随机出现在每个位置,且概率相同
从这个结论也可以推导出一点:
元素洗牌后的位置与洗牌前无关
第一点是第二点的充分条件,因此测试函数只需要测第一点就够了。测试方法为重复执行洗牌函数,统计每个位置每个元素出现的次数。这样测试函数会输出一个矩阵,如果我们不考虑针对性捣乱情况的话,测试函数可以简化为计算每个位置的值的和,这个和的方差能够近似体现洗牌函数的正确性。
def test(f, n=10000):
dataset = [list(range(10)) for i in range(n)]
for row in dataset:
f(row)
stat = []
for i in range(10):
stat.append(sum([row[i] for row in dataset]))
return stat
先测试一下标准函数的结果:
In [72]: test(random.shuffle)
Out[72]: [44909, 44468, 45506, 45184, 45086, 44883, 44574, 45145, 44656, 45589]
然后我们分析一下这个问题:
前提:在本文中不考虑伪随机数的问题,并且除了 random.random 外不能使用其他函数
每个元素分配一个随机位置的话,等于说每个元素的分配过程应该是互相独立的,与现在位置无关的。因此最简单的方法就是为每个元素分配一个随机数,然后按随机数的值进行排序这样他们的位置就该是完全随机的。
def myshuffle(array):
sort_map = {x: random.random() for x in array}
new_array = sorted(array, key=lambda x: sort_map[x])
for i in range(len(array)):
array[i] = new_array[i]
测试结果:
In [88]: test(myshuffle)
Out[88]: [45243, 44605, 45217, 45030, 45129, 44362, 45246, 45033, 45037, 45098]
据说洗牌程序业界有一个标准的 fisher_yates
算法,翻译成 Python 是这样的:
def fisher_yates(array):
for i in reversed(range(1, len(array))):
j = int(random() * (i + 1)) # j = random.choice(list(range(i + 1)))
array[i], array[j] = array[j], array[i]
即从待处理序列中随机抽出一个元素放到队尾,然后将待处理序列的尾部边界向前挪一位(如果 j 的生成不好理解就看注释的那行等价代码)。因为最后剩一个元素的时候没必要再抽,所以这个算法比上面的 myshuffle 少进行一次 random 运算,而且因为是直接调换位置,空间消耗小得多,还少了一次排序计算。
测试结果:
In [148]: test(fisher_yates)
Out[148]: [44996, 45089, 45322, 44926, 44888, 45023, 44896, 45407, 44873, 44580]
这个算法也是 Python 内建的 random.shuffle
使用的方法。
那么最后一个问题,是否还存在其他的算法,使得调用 random.random
的次数更少呢?即,是否存在算法,使序列长度为 N(N≧3) 时调用随机数生成函数的次数 k ≦ N-2?
回想随机数的用法的话,会发现我们通常都是用它来生成一个样本空间的随机下标的。即使简单如 N=2 的情况:
x = [1, 0]
def shuffle_2(x):
assert len(X) == 2
if random.random() < 0.5:
x.reverse()
我们做这个小于 0.5 的判断,其本质也是在大小为 2 的空间里选择下标。如果要使用一个随机数洗牌 3 个以上(含)的元素的话,我们就需要构建一个空间,空间里的每一个元素都含有一种唯一的元素排列形式。即该空间为序列的全排列。
def wired_shuffle(array):
all_permutations = some_function(array)
rand_index = int(len(all_permutations) * random.random())
new_array = all_permutations[rand_index]
for i in range(len(array)):
array[i] = new_array[i]
好,现在问题转化了。通过这种方式我们节省了随机函数的调用时间,却不得不生成一次序列的全排列。这带来了两个问题:
- 是否存在一种方法可以直接从下标计算出某种排列,而非全部生成一遍
- 随机函数返回的浮点数是有限的,那么这个算法能处理的序列长度也就是有限的。
其实在现有的随机函数实现下,问题2 就已经判了这个算法死刑了。但为了好玩,我们继续思考一下问题1。上面图省事没有去实现这个 some_function
,现在不得不先实现一下看看逻辑:
def some_function(array):
if not array:
return [[]]
all_permutations = []
for i, key in enumerate(array):
remaining = array[:i] + array[i + 1:]
all_permutations.extend([[key] + _array for _array in some_function(remaining)])
return all_permutations
(因为是洗牌程序用的,所以无需去重,默认每个元素都不一样就可以了。)
那么问题1的答案可以是:
def the_permutation(array, index):
_array = copy.deepcopy(array)
for i in range(len(_array)):
left_permutation_count = math.factorial(len(_array) - i - 1)
j = (index // left_permutation_count) # 计算这一位的系数
_array.insert(i, _array.pop(j + i))
index -= (j * left_permutation_count)
return _array
好了,现在我们有了一个新的 wired_shuffle:
def new_wired_shuffle(array):
index = int(math.factorial(len(array)) * random.random())
new_array = the_permutation(array, index)
for i in range(len(array)):
array[i] = new_array[i]
整合简化一下:
def new_wired_shuffle(array):
index = int(math.factorial(len(array)) * random.random())
for i in range(len(array)):
left_permutation_count = math.factorial(len(array) - i - 1)
j = (index // left_permutation_count)
array.insert(i, array.pop(j + i))
index -= (j * left_permutation_count)
测试一下:
In [171]: test(new_wired_shuffle)
Out[171]: [44967, 45259, 44984, 45141, 44940, 44820, 45165, 44865, 44854, 45005]
这次我们调用 random 的次数缩减到了 1 次,却增加了 N 次 math.factorial
调用。阶乘函数的速度会随着 N 变大而越来越慢,且 insert/pop
也比直接赋值要慢。所以,这个函数的性能到底怎么样呢?
我们拿他和 random.shuffle
对比一下(通过 timeit
):
In [200]: for N in (2, 4, 6, 8, 12, 15, 25, 50):
...: print('%d' % N)
...: compare([random.shuffle, new_wired_shuffle], N)
...:
2
shuffle: 0.020176645999526954
new_wired_shuffle: 0.032341828000426176
4
shuffle: 0.03858100000070408
new_wired_shuffle: 0.04256904300018505
6
shuffle: 0.056119916998795816
new_wired_shuffle: 0.047379309000461944
8
shuffle: 0.07575537699995039
new_wired_shuffle: 0.06619280999984767
12
shuffle: 0.09467284899983497
new_wired_shuffle: 0.09349796399874322
15
shuffle: 0.11433432200101379
new_wired_shuffle: 0.11504831499951251
25
shuffle: 0.1927355459993123
new_wired_shuffle: 0.21537912900021183
50
shuffle: 0.35780333500042616
new_wired_shuffle: 0.535689784999704
N 在 [6, 12] 区间内小胜!耶~