import numpy as np
def safe_concatenate(arr_list, output_filename):
    """Concatenate arrays along axis 0 into a disk-backed memmap.

    Each input array is copied into its slice of an on-disk ``np.memmap``
    one at a time, so peak RAM stays near the size of a single input
    instead of the full concatenated result (unlike ``np.concatenate``).

    Parameters
    ----------
    arr_list : sequence of np.ndarray
        Non-empty sequence of arrays. All must share trailing dimensions
        (``shape[1:]``) so they stack along axis 0; the result uses the
        first array's dtype.
    output_filename : str or path-like
        Destination file backing the memory map (opened in ``'w+'`` mode,
        i.e. created/overwritten).

    Returns
    -------
    np.memmap
        The concatenated array, flushed to and backed by ``output_filename``.

    Raises
    ------
    ValueError
        If ``arr_list`` is empty.
    """
    if not arr_list:
        # Guard: the original crashed with IndexError on arr_list[0].
        raise ValueError("arr_list must contain at least one array")
    # Final shape: row counts summed across inputs, trailing dims from the first.
    total_length = sum(arr.shape[0] for arr in arr_list)
    shape = (total_length, *arr_list[0].shape[1:])
    dtype = arr_list[0].dtype
    # Create the disk-backed destination buffer.
    memmap_array = np.memmap(output_filename, dtype=dtype, mode='w+', shape=shape)
    # Copy each input into its slice; only one source array is live per step.
    start_idx = 0
    for arr in arr_list:
        end_idx = start_idx + arr.shape[0]
        memmap_array[start_idx:end_idx] = arr
        start_idx = end_idx
    # (Removed the original `del arr_list`: it only dropped the local name
    # binding and had no effect on the caller's references.)
    # Push buffered writes to disk before handing the memmap back.
    memmap_array.flush()
    return memmap_array
import tracemalloc

# --- Benchmark 1: disk-backed memmap concatenation --------------------------
tracemalloc.start()
arr_list = [np.random.rand(100000, 100), np.random.rand(2000000, 100), np.random.rand(150000, 100)]
current, peak = tracemalloc.get_traced_memory()
print(f"Current memory usage: {current / 10**6} MB; Peak was {peak / 10**6} MB")
result = safe_concatenate(arr_list, 'output.dat')
current, peak = tracemalloc.get_traced_memory()
print(f"Current memory usage: {current / 10**6} MB; Peak was {peak / 10**6} MB")
tracemalloc.stop()

# --- Benchmark 2: plain in-memory np.concatenate ----------------------------
tracemalloc.start()
del arr_list, result
arr_list = [np.random.rand(100000, 100), np.random.rand(2000000, 100), np.random.rand(150000, 100)]
current, peak = tracemalloc.get_traced_memory()
print(f"Current memory usage: {current / 10**6} MB; Peak was {peak / 10**6} MB")
# BUG FIX: the original called np.concatenate(result) after `result` had just
# been deleted above (NameError). The intent is to concatenate the freshly
# rebuilt arr_list entirely in RAM, for comparison with Benchmark 1.
result = np.concatenate(arr_list)
current, peak = tracemalloc.get_traced_memory()
print(f"Current memory usage: {current / 10**6} MB; Peak was {peak / 10**6} MB")
tracemalloc.stop()
# Try running this: the memmap-based function's peak memory usage is far
# smaller than that of np.concatenate.