Pytorch nn.Dataparallel model state_dict multi-gpu to single-gpu 多块GPU训练的模型转成单块或其他GPU数量需求的模型
标签 : pytorch nn.Dataparalle model.state_dict
参考: reference link
问题描述
我们在用Pytorch训练模型的时候,可能有几组服务器,每个服务器显卡GPU配置和数量不一样,而 nn.Dataparallel保存的模型又是和显卡数量挂钩的,实际上我们需要模型能够随便转移到不同显卡数量的服务器上运行测试。下面的代码正是针对这个问题的
解决方案
重写 nn.Module内的state_dict, load_state_dict函数。也就是说,我们保存和加载的模型是不经过nn.DataParallel处理过的,所以可以在任意GPU数量上进行加载训练的。
如果你已经用多卡训了,那你只需要把下面的代码copy一下,然后再运行一个epoch即可
import torch
import torch.nn as nn