PyTorch Learning Notes: Resuming Training from a Checkpoint

1. Saving the checkpoint state

for epoch in range(starting_epoch, parse.epochs):
    for step, (x, y) in enumerate(train_loader):
        ...
    if epoch % 1 == 0:  # save a checkpoint every epoch
        # When saving the model, record the current training state in a checkpoint dict
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        # Save the checkpoint
        model_path = "./model/" + "model_{:03d}.pt".format(epoch)
        torch.save(checkpoint, model_path)
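
The checkpoint dict is not limited to these three keys. As a minimal sketch (not part of the original example, and assuming a torch.optim.lr_scheduler object named scheduler and a best_acc variable exist in your script), any other state needed to resume exactly where training stopped can go into the same dict:

checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),  # hypothetical LR scheduler
    'best_acc': best_acc,                            # best validation accuracy so far
    'rng_state': torch.get_rng_state(),              # CPU RNG state, for reproducible shuffling
}
torch.save(checkpoint, model_path)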

2. Loading the saved state

# Load the checkpoint file
checkpoint_path = "./model/model_" + parse.checkpoint + ".pt"
train_state = torch.load(checkpoint_path)
# Restore the saved state
model.load_state_dict(train_state['model_state_dict'])
optimizer.load_state_dict(train_state['optimizer_state_dict'])
starting_epoch = train_state['epoch'] + 1
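
One caveat worth noting (an addition, not from the original): torch.load restores tensors to the device they were saved from, so a checkpoint written on cuda:0 fails to load on a CPU-only machine unless map_location is given. A hedged sketch:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
train_state = torch.load(checkpoint_path, map_location=device)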

3. Complete example

import torch
from torch import optim, nn
# import visdom
import torchvision
from torch.utils.data import DataLoader

from defect import DefectDataset
from resnet import ResNet18
import matplotlib.pyplot as plt
import argparse


def parse_args():
    # 1. The first step in using argparse is to create an ArgumentParser object
    parser = argparse.ArgumentParser(description="defect detect")
    # 2. Add a positional argument; it is treated as a string by default
    parser.add_argument("echo", help="echo the string you use here")
    # 3. Add an optional integer argument
    parser.add_argument("--epochs", default=10, help="train epochs", type=int)
    # 4. Add an optional flag: action="store_true" means args.resume is set to True
    #    when the flag is given on the command line and to False otherwise
    parser.add_argument("--resume", help="resume training from a checkpoint",
                        action="store_true")
    parser.add_argument("--checkpoint", default="0", help="checkpoint num")
    return parser.parse_args()


def evalute(model, loader, device):
    model.eval()
    correct = 0
    total = len(loader.dataset)

    for x, y in loader:
        x, y = x.to(device), y.to(device)

        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()

    return correct / total


def train(parse, device, train_loader, val_loader, train_state=None):
    model = ResNet18().to(device)
    # lr is defined at module level in the __main__ block below
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()
    best_acc, best_epoch = 0, 0
    global_step = 0
    train_loss_list = []
    acc_val_list = []
    starting_epoch = 0
    if parse.resume:
        checkpoint_path = "./model/model_" + parse.checkpoint + ".pt"
        train_state = torch.load(checkpoint_path)
    if train_state is not None:
        model.load_state_dict(train_state['model_state_dict'])
        optimizer.load_state_dict(train_state['optimizer_state_dict'])
        starting_epoch = train_state['epoch'] + 1
    for epoch in range(starting_epoch, parse.epochs):
        for step, (x, y) in enumerate(train_loader):
            # print(x.size(),y.size())
            x, y = x.to(device), y.to(device)
            # print(x.size(),y.size())

            model.train()
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1
            train_loss_list.append(loss.item())
            print('epoch: ', epoch, 'step:', step, 'loss: ', loss.item())
        if epoch % 1 == 0:
            val_acc = evalute(model, val_loader, device)
            acc_val_list.append(val_acc)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                # torch.save(model.state_dict(), 'best.mdl')
                model_path = "./model/" + "model_{:03d}.pt".format(epoch)
                torch.save(checkpoint, model_path)

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    # model.load_state_dict(torch.load('best.mdl'))
    # print('loaded from ckpt!')


if __name__ == '__main__':
    parse = parse_args()
    print(parse.echo)
    epochs = parse.epochs
    batchsz = 8
    lr = 1e-3
    device = torch.device('cuda:0')
    torch.manual_seed(1234)
    train_db = DefectDataset('data', 416, mode='train')
    val_db = DefectDataset('data', 416, mode='val')
    test_db = DefectDataset('data', 416, mode='test')

    train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=8)
    val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=8)
    test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=8)

    train(parse, device, train_loader, val_loader)
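
Two usage notes (assumptions about how the script is run, not stated in the original): torch.save does not create missing directories, so the ./model folder must exist before the first checkpoint is written, for example via os.makedirs("./model", exist_ok=True). Assuming the file is saved as train.py, a fresh run could be started with python train.py hello --epochs 20, and training resumed from the epoch-5 checkpoint with python train.py hello --epochs 20 --resume --checkpoint 005, where 005 matches the zero-padded epoch number in the saved file name model_005.pt.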
