在使用并行计算的时候希望维护同一个变量,比如将高分辨率的全球数据(例如30m)重采样为0.25度的数据,全球(720,1440),原始数据是10×10度的单个文件(存到单一文件太大了),全球的话就是18×36=648个文件,如果用18个进程并行的话只能将这些文件分成36组,每个进程负责36个文件的重采样,保存到(720,1440)的数组,这样每个进程都会输出一个(720,1440)的数组,最后我们还需要将这些数据合并成一个数组,还是很麻烦的,这里给出共享内存的方法,让所有进程直接对共享内存中的数组进行读写,从而直接得到我们需要的数组:
参考官方说明:https://docs.python.org/zh-cn/3.8/library/multiprocessing.shared_memory.html
from osgeo import gdal
from osgeo import gdalconst
from zl import zl
import time
from multiprocessing import Pool,shared_memory
import numpy as np
pathin = ''
pathout = ''
a = np.zeros((12,360,720))
shm = shared_memory.SharedMemory(create=True, size=a.nbytes)
b = np.ndarray(a.shape, dtype=a.dtype, buffer=shm.buf)
b[:] = a[:]
def run(i):
existing_shm = shared_memory.SharedMemory(name=shm.name)
c = np.ndarray((12,360,720), dtype=np.float, buffer=existing_shm.buf)
c[:,:,:]=i
existing_shm.close()
if __name__ == '__main__':
print(time.strftime('%Y-%m-%d %H:%M:%S'))
p = Pool(12)
p.map(run, np.arange(12))
p.close()
p.join()
shm.unlink()
np.save(pathout + '1.npy',b)
print(time.strftime('%Y-%m-%d %H:%M:%S'))
简单来说,就是在创建数组的时候将空数组放在共享内存,然后每个进程进行计算都调用这个共享内存中的数据,多个进程可以同时对这个数组进行修改,从而实现进程之间的通信。最终保存的数据就是一个(12,360, 720)的数组,第一维度(12)每个维度上对应的数据就是维度的顺序。