Pytorch模型保存与加载模型继续训练

1. 网络模型定义与模型参数保存

定义网络模型与基本参数,以及模型训练和模型保存

使用torch.save()方法保存模型

在save_dict={}中可以保存epoch,model,optimizer,scheduler,loss等参数。

my_net = VisionTransformer()
n_epoch = 200
lr = 0.001
optimizer = optim.SGD(my_net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-6)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epoch, eta_min=lr / 100)
loss_classification = torch.nn.CrossEntropyLoss()

if cuda:
    my_net = my_net.cuda()
    loss_classification = loss_classification.cuda()

for p in my_net.parameters():
    p.requires_grad = True
bestacc = 0.0
savepth = 'mySavepthPath'
for epoch in range(n_epoch):
    my_net.train()
    ....
    if acc > bestacc:
        save_dict = {
            'epoch': epoch,
            'model': my_net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        torch.save(save_dict, savepth + '.pth')

2. 加载模型继续训练

使用torch.load加载模型,完整代码如下。

要注意的是,要先定义模型和优化器optimizer,把模型放到gpu上,然后再加载模型。
否则执行optimizer.step()时会出现下面这个错误。
Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! 
my_net = VisionTransformer()
n_epoch = 200
lr = 0.001
optimizer = optim.SGD(my_net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-6)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epoch, eta_min=lr / 100)
loss_classification = torch.nn.CrossEntropyLoss()

if cuda:
    my_net = my_net.cuda()
    loss_classification = loss_classification.cuda()

Resume = True
start_epoch = -1
if Resume:
    path_checkpoint = 'mySavepthPath.pth'
    checkpoint = torch.load(path_checkpoint, map_location=torch.device('cuda'))
    my_net.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch']
    print("start_epoch:", start_epoch)
    print('-----------------------------')


for p in my_net.parameters():
    p.requires_grad = True

bestacc = 0.0
savepth = 'mySavepthPath'

new_start = 0 if start_epoch == -1 else start_epoch
for epoch in range(start_epoch + 1, new_start+n_epoch):
    my_net.train()
    ....
    if acc > bestacc:
        save_dict = {
            'epoch': epoch,
            'model': my_net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        torch.save(save_dict, savepth + '.pth')

  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值