将多个GPU上用pytorch框架并行训练的神经网络模型应用到CPU上

本人用pytorch框架在两块GPU上并行训练了一个神经网络模型,并将训练的不同阶段的结果保存起来,以便用于模型集成。

虽然模型是在GPU上训练的,但是在服务器上部署的时候只需用CPU就可以进行模型推断。但在实际应用中,却出现如下报错信息:

RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu

解决的思路是在服务器上先将模型加载进来,然后用一种新的方式重新保存。见下面的代码段:

    model_list = ['model_1.tar','model_2.tar','model_3.tar','model_4.tar','model_5.tar']
    model_path = './sel_models/'

    # set model
    device = torch.device('cpu')
       
    for model_name in model_list:
        model = Net(num_classes=num_classes, num_channels=num_channels).to(device, dtype=torch.float)
        model = nn.DataParallel(model)
        
        # load trained model
        checkpoint = torch.load(os.path.join(model_path, model_name), map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        del checkpoint
        model = model.to(device, dtype=torch.float)
        
        cpu_model_path = './model_for_cpu/'
        if not os.path.exists(cpu_model_path):
            os.mkdir(cpu_model_path)
        
        torch.save({'model_state_dict': model.module.state_dict()},os.path.join(cpu_model_path, model_name))

这里需要注意的是 必须要加上 model = nn.DataParallel(model), 因为模型是在双GPU上并行训练的,不加这句话模型加载就会出错。另外重新保存的时候一定要加上 'module', 即 model.module.state_dict(), 而不是model.state_dict(),这也是解决这个错误的关键。

本文参考:

(26条消息) pytorch加载多GPU模型和单GPU模型(遗漏module的解决)_律己且好学,才能保证不坠入愤世嫉俗之列。-CSDN博客icon-default.png?t=L9C2https://blog.csdn.net/qq_18649781/article/details/90270323?spm=1001.2101.3001.6650.1&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7Edefault-1.no_search_link&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7Edefault-1.no_search_link

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值