解决pytorch多GPU训练的模型加载问题

本文介绍了在PyTorch中使用多GPU训练模型时,如何通过nn.DataParallel进行模型封装,并详细阐述了训练好的模型在单GPU环境下如何正确加载。主要涉及两种保存模型的方法,以及针对多GPU和单GPU模型状态字典的不同加载策略。
摘要由CSDN通过智能技术生成

在pytorch中,使用多GPU训练网络需要用到 【nn.DataParallel】:

gpu_ids = [0, 1, 2, 3]
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet()
if len(gpu_ids) > 1:
  net = nn.DataParallel(net, device_ids=gpu_ids)
net = net.to(device)

由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装,因此在原始网络结构中添加了一层module。网络结构如下:

DataParallel(
 (module): LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
 )
)

而不使用多GPU训练的网络结构如下:

LeNet(
 (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=400, out_features=120, bias=True)
 (fc2): Linear(in_features=120, out_features=84, bias=True)
 (fc3): Linear(in_features=84, out_features=10, bias=True)
)

重点多GPU训练好的模型,单GPU 如何正确加载

方法一:(训练的时候必须写好的不同保存模式)

if len(gpu_ids) > 1:
  t.save(net.module.state_dict(), "model.pth")
else:
  t.save(net.state_dict(), "model.pth")

或者写入字典

def save_model (model, cudan=4):
 savepath = str(dir_checkpoint) + '/best_model.pth'
  # 定义要保存的模型的字典
 state = {
      'epoch': nb + 1,
      'mIoU': newmIoU,
      'dev_loss': dev_loss,
      "lr:":lr,
      # 'model_state_dict':  model.module.state_dict(),  # 保存多GPU网络模型的字典
      # 'model_state_dict': model.state_dict(),  # 保存单GPU模型的字典
      'optimizer_state_dict': optimizer.state_dict(),
  }
  # 保存网络模型 https://blog.csdn.net/anshiquanshu/article/details/122157157
  if cudan > 1:  # 并行的保存
      state['model_state_dict'] = model.module.state_dict()  # 多GPU
  else:
      state['model_state_dict'] = model.state_dict() # 单GPU模型的字典

  torch.save(state, savepath) 

方法二:无论有几个GPU 都按并行的方式加载即可,一个也可以,不影响

model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint["model_state_dict"].items()})
    model = nn.DataParallel(model).cuda()

或者字典中的加载

model2 = net()
model2.load_state_dict({k.replace('module.', ''):v for k, v in torch.load('demo.pth').items()})
model2 = nn.DataParallel(model2).cuda()

参考链接:
[1]https://blog.csdn.net/anshiquanshu/article/details/122157157
[2] https://www.jb51.net/article/189297.htm

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值