Code Walkthrough 1: Deep Image Matting (3)

Last time we covered what args does; today we move on to the main event of this big project: train_net.

Let's start by looking at how the whole train_net function runs.

# Imports implied by this function; DIMModel, DIMDataset, device, get_logger, train, valid
# and the checkpoint/learning-rate helpers come from the project's own modules
import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter  # the repo may pull this from tensorboardX instead

def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_loss = float('inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0
    decays_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        model = DIMModel(n_classes=1, in_channels=4, is_unpooling=True, pretrain=True)
        model = nn.DataParallel(model)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model'].module
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)

    # Custom dataloaders
    train_dataset = DIMDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8)
    valid_dataset = DIMDataset('valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        if args.optimizer == 'sgd' and epochs_since_improvement == 10:
            break

        if args.optimizer == 'sgd' and epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:
            checkpoint = 'BEST_checkpoint.tar'
            checkpoint = torch.load(checkpoint)
            model = checkpoint['model']
            optimizer = checkpoint['optimizer']
            decays_since_improvement += 1
            print("\nDecays since last improvement: %d\n" % (decays_since_improvement,))
            adjust_learning_rate(optimizer, 0.6 ** decays_since_improvement)

        # One epoch's training
        train_loss = train(train_loader=train_loader,
                           model=model,
                           optimizer=optimizer,
                           epoch=epoch,
                           logger=logger)
        effective_lr = get_learning_rate(optimizer)
        print('Current effective learning rate: {}\n'.format(effective_lr))

        writer.add_scalar('Train_Loss', train_loss, epoch)
        writer.add_scalar('Learning_Rate', effective_lr, epoch)

        # One epoch's validation
        valid_loss = valid(valid_loader=valid_loader,
                           model=model,
                           logger=logger)

        writer.add_scalar('Valid_Loss', valid_loss, epoch)

        # Check if there was an improvement
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0
            decays_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)

After a bit of digging, the variables at the top of the function can be annotated as follows:

torch.manual_seed(7) and np.random.seed(7) pin the random seeds so runs are reproducible, best_loss starts at positive infinity so that the very first validation loss counts as an improvement, and writer is a TensorBoard SummaryWriter used to log the training curves. epochs_since_improvement and decays_since_improvement only come into play later in the loop, so I'll explain exactly what they do further down. end_epoch was touched on in the last post about argparse: in other words, it's a default declared ahead of time that the code simply uses, and it can be changed from the command line when launching training.
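
As a quick refresher, here is a minimal argparse sketch of how such a default might be declared and then overridden from the command line; the exact flag names and default values are my assumptions, not necessarily the repo's:

# Hypothetical sketch -- flag names and defaults are assumptions
import argparse

parser = argparse.ArgumentParser(description='Train Deep Image Matting')
parser.add_argument('--end-epoch', type=int, default=100, help='training stops at this epoch')
parser.add_argument('--optimizer', type=str, default='sgd', help="'sgd' or 'adam'")
parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint to resume from')

args = parser.parse_args(['--end-epoch', '120'])  # same effect as `python train.py --end-epoch 120`
print(args.end_epoch)                             # 120: the command line overrides the default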

The next step checks whether a checkpoint, i.e. an already-trained model, exists. If it doesn't, a fresh model is created, and the optimizer type then decides whether the model trains with SGD or Adam. For space reasons I won't dissect DIMModel today; the goal here is to go over the overall structure of train_net.
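
For reference, a checkpoint here is just a Python dict saved with torch.save. Judging from the keys read back above ('epoch', 'epochs_since_improvement', 'model', 'optimizer'), save_checkpoint plausibly looks something like the sketch below; the real helper lives in the repo's utility module and may differ in details:

# Hypothetical sketch of save_checkpoint, inferred from the keys that train_net reads back
import torch

def save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best):
    state = {'epoch': epoch,
             'epochs_since_improvement': epochs_since_improvement,
             'loss': best_loss,
             'model': model,
             'optimizer': optimizer}
    torch.save(state, 'checkpoint.tar')
    if is_best:
        # the SGD branch in the loop reloads exactly this file when the loss plateaus
        torch.save(state, 'BEST_checkpoint.tar')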

There is also the nn.DataParallel line, which lets the model run on multiple GPUs. Training DIM on a single GPU is both memory-limited and painfully slow, so people usually throw several cards at it; I once tried eight RTX 3090s and it really was lightning fast.
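
A minimal sketch of what that wrapping does, with a small Linear layer standing in for DIMModel:

# Multi-GPU data parallelism with nn.DataParallel (stand-in model, also runs on CPU)
import torch
from torch import nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = nn.Linear(4, 1)                # stand-in for DIMModel
if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)       # replicates the module and splits each batch across GPUs
net = net.to(device)

# Once wrapped, the underlying model sits at net.module --
# which is why the resume branch above unwraps it with checkpoint['model'].module
if isinstance(net, nn.DataParallel):
    print(type(net.module))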

The concrete implementation of the dataset and DataLoader is skipped this time; after a quick sketch of what the loaders produce (below), we go straight to the key part: the training loop.
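
Purely for intuition, a tiny stand-in for the two loaders above. A TensorDataset replaces DIMDataset here, and the 320x320 crop size and the single-channel alpha target are my assumptions; the only thing taken from the code is the 4-channel input implied by in_channels=4 (RGB plus trimap):

# Stand-in loader: 4-channel inputs (RGB + trimap) and alpha targets (shapes assumed)
import torch
from torch.utils.data import TensorDataset, DataLoader

fake_inputs = torch.randn(16, 4, 320, 320)   # matches in_channels=4 in DIMModel
fake_alphas = torch.rand(16, 1, 320, 320)    # ground-truth alpha mattes
loader = DataLoader(TensorDataset(fake_inputs, fake_alphas), batch_size=4, shuffle=True)

for x, y in loader:
    print(x.shape, y.shape)   # torch.Size([4, 4, 320, 320]) torch.Size([4, 1, 320, 320])
    break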

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        if args.optimizer == 'sgd' and epochs_since_improvement == 10:
            break

        if args.optimizer == 'sgd' and epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:
            checkpoint = 'BEST_checkpoint.tar'
            checkpoint = torch.load(checkpoint)
            model = checkpoint['model']
            optimizer = checkpoint['optimizer']
            decays_since_improvement += 1
            print("\nDecays since last improvement: %d\n" % (decays_since_improvement,))
            adjust_learning_rate(optimizer, 0.6 ** decays_since_improvement)

        # One epoch's training
        train_loss = train(train_loader=train_loader,
                           model=model,
                           optimizer=optimizer,
                           epoch=epoch,
                           logger=logger)
        effective_lr = get_learning_rate(optimizer)
        print('Current effective learning rate: {}\n'.format(effective_lr))

        writer.add_scalar('Train_Loss', train_loss, epoch)
        writer.add_scalar('Learning_Rate', effective_lr, epoch)

        # One epoch's validation
        valid_loss = valid(valid_loader=valid_loader,
                           model=model,
                           logger=logger)

        writer.add_scalar('Valid_Loss', valid_loss, epoch)

        # Check if there was an improvement
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0
            decays_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)

At first glance it looks simple, but it's really layers of nesting inside nesting... The initial state is:

start_epoch=0

end_epoch = 100 (can be set from the command line)

epochs_since_improvement = 0

decays_since_improvement = 0

From here there is one key thing to figure out: why set up these *_since_improvement variables at all? Right at the top of the loop, training is stopped once epochs_since_improvement == 10, and that is the question worth chewing on. So I kept my eyes on epochs_since_improvement and decays_since_improvement alone, and the answer only turns up in the very last block of the code.

        # One epoch's validation
        valid_loss = valid(valid_loader=valid_loader,
                           model=model,
                           logger=logger)

        writer.add_scalar('Valid_Loss', valid_loss, epoch)

        # Check if there was an improvement
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0
            decays_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)

As mentioned before, best_loss starts at positive infinity precisely so that the later loss, valid_loss, can update it. If valid_loss is smaller than best_loss, best_loss is updated to that smaller value, is_best becomes True, and at that moment epochs_since_improvement and decays_since_improvement are both reset to 0. Conversely, if best_loss is still the smaller one (no improvement), epochs_since_improvement is incremented by 1.

Now go back to the beginning of the loop: when epochs_since_improvement == 10, the loop is terminated (note that this early stop only fires when the optimizer is SGD). In other words, the validation loss has failed to improve for ten epochs in a row, and that is where this variable earns its keep: it avoids wasting resources on training that is going nowhere. If the loss can't improve ten times in a row, there's no point in playing on. A tiny simulation of the bookkeeping is sketched below.
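
To make that bookkeeping concrete, here is a small standalone simulation with made-up validation losses (no model involved); it mirrors the counter logic from the loop above:

# Toy simulation of the improvement counter -- the losses are made up
best_loss = float('inf')
epochs_since_improvement = 0

fake_valid_losses = [0.9, 0.7, 0.71, 0.72] + [0.75] * 12   # improves twice, then never again

for epoch, valid_loss in enumerate(fake_valid_losses):
    if epochs_since_improvement == 10:
        print('early stop at epoch', epoch)
        break
    is_best = valid_loss < best_loss
    best_loss = min(valid_loss, best_loss)
    if not is_best:
        epochs_since_improvement += 1
    else:
        epochs_since_improvement = 0
    print(epoch, valid_loss, 'epochs_since_improvement =', epochs_since_improvement)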

So what is decays_since_improvement all about? Look a little further up.

        if args.optimizer == 'sgd' and epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:
            checkpoint = 'BEST_checkpoint.tar'
            checkpoint = torch.load(checkpoint)
            model = checkpoint['model']
            optimizer = checkpoint['optimizer']
            decays_since_improvement += 1
            print("\nDecays since last improvement: %d\n" % (decays_since_improvement,))
            adjust_learning_rate(optimizer, 0.6 ** decays_since_improvement)

Here, on entering the loop, if the epochs_since_improvement carried over from earlier is greater than 0 and divisible by 2 (a positive even number), the code goes to work on the checkpoint: it reloads the best model and optimizer from 'BEST_checkpoint.tar', bumps decays_since_improvement, and calls adjust_learning_rate with a factor of 0.6 ** decays_since_improvement, shrinking the learning rate each time this branch fires. I have never used the checkpoints the author provides, and in a normal run the loss keeps pushing lower after each epoch, so my reading is that when the best loss has stopped improving for a stretch, the code rolls back to the best weights so far and carries on with a smaller learning rate rather than letting the model drift further. Plenty of training scripts I've read since are written this way, so this piece matters quite a bit.
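
adjust_learning_rate and get_learning_rate are small helpers from the repo's utility code. I haven't reproduced them verbatim; the sketch below is a plausible version, assuming the factor simply rescales each parameter group's learning rate:

# Plausible sketch of the two LR helpers -- the repo's actual versions may differ
import torch

def adjust_learning_rate(optimizer, shrink_factor):
    # scale the learning rate of every parameter group by shrink_factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor
    print('New learning rate:', optimizer.param_groups[0]['lr'])

def get_learning_rate(optimizer):
    return optimizer.param_groups[0]['lr']

# Example: with decays_since_improvement = 2 the factor is 0.6 ** 2 = 0.36
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.01)
adjust_learning_rate(opt, 0.6 ** 2)
print(get_learning_rate(opt))   # -> 0.0036 (up to float rounding)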

By this point the rough structure of the training code is already laid out. Once the preparation is done, each epoch runs the model through training and validation, gets the losses, and updates the bookkeeping, but the finer details will have to wait for later posts:

1. The structure of the DIM model

2. What train and valid actually do

3. What the writer is really for

Those holes will be filled in later posts.
