import multiprocessing as tmp
import torch as t
import numpy as np
import time
def sample(args):
    """Gather one worker's share of a batch of dataset windows.

    ``args`` is a 6-tuple ``(rank, meta_works, batch_size, indexes, indexes2,
    dataset)``.  This worker handles batch positions
    ``range(rank * meta_works, min((rank + 1) * meta_works, batch_size))`` and,
    for each position ``i``, takes the slice ``dataset[indexes[i]:indexes2[i]]``.

    Returns:
        The slices stacked along a new leading dimension.  ``t.stack`` needs
        every slice to have the same shape, so the windows in a chunk are
        assumed to have equal length.  When the chunk is empty (which happens
        when ``rank * meta_works >= batch_size``), an empty tensor of shape
        ``(0, width, *dataset.shape[1:])`` is returned instead of letting
        ``t.stack([])`` raise.  ``width`` is the length of the batch's first
        window, or 0 if there are no indexes.
    """
    rank, meta_works, batch_size, indexes, indexes2, dataset = args
    start = rank * meta_works
    stop = min((rank + 1) * meta_works, batch_size)
    windows = [dataset[indexes[i]:indexes2[i]] for i in range(start, stop)]
    if not windows:
        # t.stack([]) raises.  Return an empty batch whose trailing dims match
        # what non-empty workers produce, so the caller can still t.cat() them.
        width = dataset[indexes[0]:indexes2[0]].shape[0] if len(indexes) else 0
        return dataset.new_empty((0, width, *dataset.shape[1:]))
    return t.stack(windows)
class SampleClass():
    """Benchmark: batch window sampling with a process pool vs. a plain loop.

    ``dataset`` is a uint8 tensor of shape ``(length, 84, 84)`` in which every
    frame is filled with one random value.  Each repeat draws ``batch_size``
    start indexes ``i`` and gathers the 4-frame windows ``dataset[i:i + 4]``.

    NOTE(review): every ``map_async`` call in ``method1`` passes
    ``self.dataset`` and the index arrays to the workers as task arguments.
    It also receives the stacked windows back, while the work per task is only
    a few small slices.  That inter-process transfer is the likely reason the
    pool version is slower than the plain loop in ``method2``.  Profile
    before relying on this; a single vectorized gather in-process is
    presumably the faster design.

    The instance owns a worker pool: call ``close()`` (or use it as a context
    manager) when done so the worker processes are released.
    """

    def __init__(self, length=10000, batch_size=200, num_processes=8,
                 num_repeat=1000):
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        self.length = length
        self.batch_size = batch_size
        # Build the frames directly as uint8.  The original allocated a float32
        # zeros tensor (4x the memory) and converted it.  uint8 + uint8 stays
        # uint8, so each frame simply holds its random value.
        frame_values = t.randint(256, size=(self.length,)).type(t.uint8)
        self.dataset = (t.zeros(size=(self.length, 84, 84), dtype=t.uint8)
                        + frame_values.unsqueeze(-1).unsqueeze(-1))
        self.num_processes = num_processes
        self.num_repeat = num_repeat
        self._draw_indexes()
        # Batch positions handled per worker (the last chunks may be short).
        self.meta_works = int(np.ceil(self.batch_size / self.num_processes))
        self.pool = tmp.Pool(processes=self.num_processes)
        # Warm-up round.  Wait for it, so that its tasks are not still queued
        # ahead of (and timed as part of) the first round of method1.
        self.result = self.pool.map_async(sample, self._task_args())
        self.result.wait()

    def _draw_indexes(self):
        """Draw a fresh batch of window bounds into ``indexes``/``indexes2``."""
        # Starts lie in [0, length - 8), so every 4-frame window stays in range.
        self.indexes = np.random.randint(self.length - 8, size=self.batch_size)
        self.indexes2 = self.indexes + 4

    def _task_args(self):
        """Build one ``sample`` argument tuple per non-empty chunk of the batch."""
        # With meta_works = ceil(batch_size / num_processes), trailing ranks can
        # get an empty range.  For example, batch_size=10 and num_processes=8
        # leave ranks 5-7 with nothing, and ``sample`` would then call
        # t.stack([]).  Dispatch only the ranks that have work.
        num_chunks = int(np.ceil(self.batch_size / self.meta_works))
        return [(rank, self.meta_works, self.batch_size, self.indexes,
                 self.indexes2, self.dataset) for rank in range(num_chunks)]

    def method1(self):
        """Gather ``num_repeat`` batches using the process pool."""
        for _ in range(self.num_repeat):
            self._draw_indexes()
            self.result = self.pool.map_async(sample, self._task_args())
            # Merge the per-worker chunks into one (batch_size, 4, 84, 84)
            # batch.  The batch is discarded, but the merge is kept so it is
            # part of the measured cost.
            t.cat(self.result.get())

    def method2(self):
        """Gather ``num_repeat`` batches with a plain in-process loop."""
        for _ in range(self.num_repeat):
            self._draw_indexes()
            # Same gather as method1, done serially; the batch is discarded.
            t.stack([self.dataset[start:stop]
                     for start, stop in zip(self.indexes, self.indexes2)])

    def close(self):
        """Shut down the worker pool and wait for its processes to exit."""
        self.pool.close()
        self.pool.join()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
if __name__ == "__main__":
    # This guard is required under multiprocessing's "spawn" start method (the
    # default on Windows and macOS).  Workers re-import this module, and without
    # the guard each of them would re-run this script and try to create its own
    # Pool, which multiprocessing rejects with a RuntimeError.
    sampleclass = SampleClass()
    try:
        # perf_counter is monotonic and high-resolution, so it suits interval
        # timing better than the wall clock returned by time.time().
        starttime = time.perf_counter()
        sampleclass.method1()
        spendtime = time.perf_counter() - starttime
        print('method1 time: {:.6f}'.format(spendtime))
        starttime = time.perf_counter()
        sampleclass.method2()
        spendtime = time.perf_counter() - starttime
        print('method2 time: {:.6f}'.format(spendtime))
    finally:
        # Release the worker processes created in SampleClass.__init__.
        sampleclass.pool.close()
        sampleclass.pool.join()
結果如下
method1 time: 16.703726
method2 time: 1.195947