Pytorch nn.Dataparallel model state_dict multi-gpu to single-gpu

reference linkimport torchimport torch.nn as nnfrom collections import OrderedDictfrom torch.nn.parameter import Parameterdef state_dict(model, destination=None, prefix='', keep_vars=False):...
摘要由CSDN通过智能技术生成

Pytorch nn.Dataparallel model state_dict multi-gpu to single-gpu 多块GPU训练的模型转成单块或其他GPU数量需求的模型

标签 : pytorch nn.Dataparalle model.state_dict


参考: reference link

问题描述

我们在用Pytorch训练模型的时候,可能有几组服务器,每个服务器显卡GPU配置和数量不一样,而 nn.Dataparallel保存的模型又是和显卡数量挂钩的,实际上我们需要模型能够随便转移到不同显卡数量的服务器上运行测试。下面的代码正是针对这个问题的

解决方案

重写 nn.Module内的state_dict, load_state_dict函数。也就是说,我们保存和加载的模型是不经过nn.DataParallel处理过的,所以可以在任意GPU数量上进行加载训练的。
如果你已经用多卡训了,那你只需要把下面的代码copy一下,然后再运行一个epoch即可

import torch
import torch.nn as nn
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值