Loading Very Large Training Data in PyTorch
by joeyqzhou
Related code: https://github.com/joeyqzhou/blog/tree/master/pytorch%E4%B8%ADload%E8%B6%85%E5%A4%A7%E8%AE%AD%E7%BB%83%E6%95%B0%E6%8D%AE
The simplest approach:
1 Load the data into memory in a single thread
2 Train:
for epoch in range(num_epochs):
    for i in range(inst_size):
        # slice out batch_x, batch_y
        # convert batch_x, batch_y to tensors
        # model.forward()
        # loss.backward()
        # optimizer.step()
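The loop above can be sketched as a runnable example. The synthetic data and the small linear model below are assumptions for illustration only; substitute your own data loading and model:

```python
import torch
import torch.nn as nn

# synthetic in-memory dataset: 100 examples, 4 features (assumption)
X = torch.randn(100, 4)
Y = torch.randn(100, 1)

model = nn.Linear(4, 1)          # placeholder model (assumption)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

num_epochs = 3
batch_size = 10
for epoch in range(num_epochs):
    for i in range(0, X.size(0), batch_size):
        # slice out batch_x, batch_y
        batch_x = X[i:i + batch_size]
        batch_y = Y[i:i + batch_size]
        optimizer.zero_grad()
        loss = criterion(model(batch_x), batch_y)  # model.forward()
        loss.backward()
        optimizer.step()
```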
This approach keeps the code simple, but all the data must sit in memory. When the training data is very large, loading is slow and memory can overflow.
Loading data with multiple processes
Below is an example of loading data with multiple processes:
from multiprocessing import Pool

def process_line(line):
    # parse one raw line into a training example
    return "FOO: %s" % line.strip()

if __name__ == "__main__":
    file = "train.txt"  # your input data
    with Pool(4) as pool, open(file) as source_file:
        # chunk the work: chunksize=4 sends 4 lines at a time to each
        # worker, reducing inter-process communication overhead
        results = pool.map(process_line, source_file, chunksize=4)
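Note that pool.map still builds the full result list in memory, so for truly large files it only parallelizes parsing, not memory use. A minimal sketch below uses pool.imap instead, which streams parsed lines lazily and yields mini-batches; the parse_line format ("label f1 f2 ...") and the parameter defaults are assumptions for illustration:

```python
from multiprocessing import Pool

def parse_line(line):
    # hypothetical line format (assumption): "label f1 f2 ..."
    parts = line.split()
    return [float(x) for x in parts[1:]], int(parts[0])

def iter_batches(path, batch_size=32, workers=4):
    # imap consumes the file lazily and returns results in order,
    # so only ~one chunk per worker is in flight at any moment
    with Pool(workers) as pool, open(path) as f:
        batch = []
        for example in pool.imap(parse_line, f, chunksize=256):
            batch.append(example)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:  # final partial batch
            yield batch
```

Each yielded batch can then be converted to tensors and fed to the training loop, so the whole file never has to fit in memory at once.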