Pytorch加载模型

一、假设我们只保存了模型的参数(model.state_dict())到文件名为modelparameters.pth, model = Net()

1. cpu -> cpu或者gpu -> gpu:

checkpoint = torch.load('modelparameters.pth')

model.load_state_dict(checkpoint)

2. cpu -> gpu 0

checkpoint = torch.load('modelparameters.pth', map_location=lambda storage, loc: storage.cuda(0))

model.load_state_dict(checkpoint)

3. gpu 0 -> gpu 1

checkpoint = torch.load('modelparameters.pth', map_location={'cuda:1':'cuda:0'})

model.load_state_dict(checkpoint)
4. gpu -> cpu
checkpoint = torch.load('modelparameters.pth', map_location=lambda storage, loc: storage)

model.load_state_dict(checkpoint)
二、pytorch完整加载模型实例
# 模型是GPU训练的
def load_models(model_param_path):
    torch.manual_seed(14)
    np.random.seed(14)
    random.seed(14)
    if torch.cuda.is_available():
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(14)

    model = FeatherNetB()
    model = torch.nn.DataParallel(model)
    
    device = torch.device('cuda:' + str(0) if torch.cuda.is_available() else "cpu")
    model.to(device)

    print("=> loading checkpoint '{}'".format(param_path))
    if torch.cuda.is_available():
        #gpu测试
        checkpoint = torch.load(model_param_path)
    else:
        #cpu测试
        checkpoint = torch.load(model_param_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    
    #测试的时候必须加上model.eval(),训练就用model.train()
    model.eval()             
       
    return model

 

  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值