pytorch训练神经网络——模型保存及加载

官方推荐使用:

# 保存网络中的参数, 速度快,占空间少

torch.save(model.state_dict(),PATH)

对应的加载模型代码则为

model_dict=model.load_state_dict(torch.load(PATH))

 

示例

#训练模型
# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        ...

    def forward(self, x):
        ...
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

model.train()

torch.save({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()},PATH)

#测试模型
#load model
def load_checkpoint(model, checkpoint_PATH, optimizer):
    if checkpoint != None:
        model_CKPT = torch.load(checkpoint_PATH)
        model.load_state_dict(model_CKPT['state_dict'])
        print('loading checkpoint!')
        optimizer.load_state_dict(model_CKPT['optimizer'])
    return model, optimizer


model = TheModelClass()

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

model, optimizer=load_checkpoint(model, PATH, optimizer)

model.test()

此时若出现报错“Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!”。则说明存储和加载模型使用了不同的设备。

1)GPU保存,CPU加载

加载模型代码

    device=torch.device("cup")
    model= TheModelClass()
    model.load_state_dict(torch.load(PATH,map_location=device))

2)GPU保存,GPU加载

加载模型代码

    device=torch.device("cuda")
    model= TheModelClass()
    model.load_state_dict(torch.load(PATH))
    model.to(device)

3)CPU保存,GPU加载

加载模型代码

    device=torch.device("cuda")
    model= TheModelClass()
    model.load_state_dict(torch.load(PATH,map_location="cuda:0"))
    model.to(device)

 

 

 

 

 

  • 0
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值