Basic Template Usage

This article covers how to save model state and manage checkpoints in a deep learning project, including updating the best model, inspecting model parameters, adjusting the learning rate, and an SBN (scaled Batch Normalization) implementation used in validation.

Saving the model

import os
import logging

import torch

logger = logging.getLogger(__name__)

def main():
    ...
    # Inside the training loop: track the lowest validation loss seen so far
    is_best = valid_loss < best_loss
    if is_best:
        best_epoch = epoch
    best_prec = best_loss = min(valid_loss, best_loss)

    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict(),
    }, is_best, fdir)
        
def save_checkpoint(state, is_best, fdir):
    filepath = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, filepath)
    if is_best:
        torch.save(state, os.path.join(fdir, 'model_best.pth.tar'))
        logger.info("Best model is updated")
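
The snippet above assumes `fdir`, `best_loss`, and `valid_loss` are prepared earlier in `main`. A minimal setup sketch (the `args.save_dir` name is an assumption, not from the original):

# Setup sketch before the training loop (args.save_dir is a hypothetical name)
fdir = args.save_dir
os.makedirs(fdir, exist_ok=True)
best_loss = float('inf')   # lowest validation loss seen so far
best_epoch = 0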

Inspecting and modifying the model

# Print parameter names and shapes for the model, a loaded checkpoint, and the optimizer state
for key in model.state_dict():
    print(key, model.state_dict()[key].size())
for key in checkpoint:
    print(key, checkpoint[key].shape)
for var_name in optimizer.state_dict():
    print(var_name, '\t', optimizer.state_dict()[var_name])

# Rename checkpoint keys: strip a trailing 'thresh' and replace it with 'up'
keys = list(checkpoint.keys())
for key in keys:
    if 'thresh' in key:
        checkpoint[key[:-6] + 'up'] = checkpoint.pop(key)

# The key names differ completely, but the parameter counts match and correspond
# one-to-one, so map checkpoint values onto model keys by position:
key_index = list(model.state_dict().keys())
for item, (key, value) in enumerate(checkpoint.items()):
    model_key = key_index[item]
    model.load_state_dict({model_key: value}, strict=False)
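
Because this positional mapping silently assumes both state dicts enumerate their parameters in the same order, a quick shape comparison before loading helps catch mistakes. A small sketch, not part of the original template:

# Sanity check: the i-th model parameter should match the i-th checkpoint entry in shape
model_state = model.state_dict()
assert len(model_state) == len(checkpoint), "parameter counts differ"
for model_key, (ckpt_key, value) in zip(model_state, checkpoint.items()):
    if model_state[model_key].shape != value.shape:
        print(f"shape mismatch: {model_key} {tuple(model_state[model_key].shape)} "
              f"vs {ckpt_key} {tuple(value.shape)}")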

Loading the model

    model = modelpool(args.model, args.dataset)
    start_epoch = 0
    best_prec = 0
    quan.replace_modules(model, args)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)

    if args.resume.path:
        # model, start_epoch, _ = util.load_checkpoint(
        #     model, args.resume.path, 'cuda', lean=args.resume.lean)
        checkpoint = torch.load(args.resume.path)           # load the saved checkpoint file
        model.load_state_dict(checkpoint['state_dict'])     # restore the model weights
        optimizer.load_state_dict(checkpoint['optimizer'])  # restore the optimizer state
        start_epoch = checkpoint['epoch']                    # epoch to resume from
        best_prec = checkpoint['best_prec']                  # best metric recorded so far
        # Move optimizer state tensors (e.g. momentum buffers) back onto the GPU
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        print(f"start epoch: {start_epoch}, lr: {optimizer.param_groups[0]['lr']}, best_prec: {best_prec}")

    model.cuda()
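
If the checkpoint was written on a different device than the one resuming, `torch.load` can be given a `map_location` so the tensors are first loaded onto the CPU (or a specific GPU):

# Load onto CPU regardless of the device the checkpoint was saved from
checkpoint = torch.load(args.resume.path, map_location='cpu')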

adjust_lr

def adjust_learning_rate(optimizer, epoch, lr):
    # Decay the initial lr by a factor of 10 every 60 epochs
    lr *= (0.1 ** (epoch // 60))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def lr_scheduler(optimizer, epoch):
    # Multiply the current lr by 0.1 at fixed milestone epochs
    lr_list = [100, 140, 240]
    if epoch in lr_list:
        print('change the learning rate')
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.1

for epoch in range(start_epoch, args.epoch):
    adjust_learning_rate(optimizer, epoch, args.lr)
    lr = optimizer.param_groups[0]['lr']
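
The milestone-based `lr_scheduler` above does essentially what PyTorch's built-in `MultiStepLR` provides; an equivalent sketch:

# Built-in equivalent: multiply the lr by 0.1 at epochs 100, 140 and 240
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[100, 140, 240], gamma=0.1)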

cosine:

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)	
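
Constructing the scheduler alone does not change the learning rate; it has to be stepped once per epoch:

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
for epoch in range(start_epoch, args.epoch):
    ...                 # one training epoch
    scheduler.step()    # anneal the lr along a cosine curve over T_max epochs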

SBN in validation

import torch
import torch.nn as nn

class SBN(nn.Module):
    """BatchNorm whose bias and running mean are divided by T, for use at validation."""
    def __init__(self, bn, T) -> None:
        super().__init__()
        self.eps = bn.eps
        self.weight = bn.weight
        self.bias = bn.bias / T
        self.val_mean = bn.running_mean / T
        self.val_var = bn.running_var

    def forward(self, x):
        # Expects (N, C, L) input; use batch statistics in training,
        # the T-scaled running statistics in evaluation
        if self.training:
            mean = x.mean([0, 2])
            var = x.var([0, 2], unbiased=False)
        else:
            mean = self.val_mean
            var = self.val_var
        x = x - mean[None, ..., None]
        x = x / torch.sqrt(var[None, ..., None] + self.eps)
        x = x * self.weight[..., None] + self.bias[..., None]
        return x
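
One way to apply this at validation time is to swap every `nn.BatchNorm1d` in the trained model for an `SBN` built from it. A minimal sketch, assuming a recursive replacement helper and an arbitrary T=8 (neither is from the original):

def replace_bn_with_sbn(module, T):
    # Recursively replace every nn.BatchNorm1d child with an SBN wrapper around it
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm1d):
            setattr(module, name, SBN(child, T))
        else:
            replace_bn_with_sbn(child, T)

replace_bn_with_sbn(model, T=8)    # T=8 is only an example value
model.eval()                       # so SBN uses the T-scaled running statistics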
         with open(f"model_params.txt", "w") as fo:
             fo.write(f"{len(model.state_dict())}")
            for key in model.state_dict():
                     fo.write(f"{key}\t{model.state_dict()[key].size()}\n")
         with open(f"checkpoint_params.txt", "w") as f:
             f.write(f"{len(checkpoint)}")
             for key in checkpoint:
                 f.write(f"{key}\t{checkpoint[key].shape}\n")