我想从itertools.product的结果构建一个numpy数组。我的第一种方法很简单:
from itertools import product
import numpy as np
max_init = 6
init_values = range(1, max_init + 1)
repetitions = 12
result = np.array(list(product(init_values, repeat=repetitions)))
此代码对于"小" repetitions(如<= 4)效果很好,但是对于"大"值(> = 12),它将完全占用内存并崩溃。我以为建立列表就是吃掉所有RAM的事情,所以我研究了如何直接使用数组来建立它。我发现Numpy等效于itertools.product,并使用numpy构建两个数组的所有组合的数组。
因此,我测试了以下替代方案:
备选方案1:
results = np.empty((max_init**repetitions, repetitions))
for i, row in enumerate(product(init_values, repeat=repetitions)):
result[:][i] = row
备选方案2:
init_values_args = [init_values] * repetitions
results = np.array(np.meshgrid(*init_values_args)).T.reshape(-1, repetitions)
备选方案3