0、背景
【目的】:现有一个数据条数很大(约 25w)的 numpy array:allData。想根据字典 aDict 里的值选取出其中的某些数据(约 2w)条组成一个新的 numpy array:data.
原代码大概的逻辑如下:(只放出了和我猜想和结论相关的部分)
data = np.array([])
for key, value in aDict.items():
data = np.concatenate([data, allData[key]])
1、问题
代码到后面运行速度变得越来越慢
2、猜测 1:是不是由于字典数据量太大造成的?
【排除方法】:将 for 循环里面的代码注释掉,看运行时间。
【结论】:注释掉之后运行速度很快,所以问题应该出在被注释掉的 for 循环里面的语句。
3、正确推理 & 分析原因
定位到为题所在之后,搜了一下 np.concatenate() 的原理:python - Concatenate Numpy arrays without copying - Stack Overflow
numpy array 的内存必须是连续的,所以进行 np.concatenate() 时,相当于重新分配了一个大数组,再把要 concat 起来的小数组里的值全部 copy 进去。
所以耗时的原因在这里,在 for 循环里 concatenate,相当于每次都要重新分配并复制,越到后面,需要复制的值越多,所以就越慢。
4、解决方案
在进入 for 循环之前,先分配好最终 numpy array(data)需要的所有空间。
在 for 循环里,直接往 numpy array(data)里赋值即可。
代码如下:
data = np.empty(shape=(20000, ..., ..., ..., ...))
count = 0
for key, value in aDict.items():
data[count, :, :, :, :] = allData[idx, :, :, :, :]
count += 1