手撕代码:deep image matting (5)train和valid函数

这两个坑放到一起说,因为对应的代码长得实在是太像了。先上代码细扣一下。

def train(train_loader, model, optimizer, epoch, logger):
    model.train()  # train mode (dropout and batchnorm is used)

    losses = AverageMeter()

    # Batches
    for i, (img, alpha_label) in enumerate(train_loader):
        # Move to GPU, if available
        img = img.type(torch.FloatTensor).to(device)  # [N, 4, 320, 320]
        alpha_label = alpha_label.type(torch.FloatTensor).to(device)  # [N, 320, 320]
        alpha_label = alpha_label.reshape((-1, 2, im_size * im_size))  # [N, 320*320]

        # Forward prop.
        alpha_out = model(img)  # [N, 3, 320, 320]
        alpha_out = alpha_out.reshape((-1, 1, im_size * im_size))  # [N, 320*320]

        # Calculate loss
        # loss = criterion(alpha_out, alpha_label)
        loss = alpha_prediction_loss(alpha_out, alpha_label)

        # Back prop.
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        clip_gradient(optimizer, grad_clip)

        # Update weights
        optimizer.step()

        # Keep track of metrics
        losses.update(loss.item())

        # Print status

        if i % print_freq == 0:
            status = 'Epoch: [{0}][{1}/{2}]\t' \
                     'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i, len(train_loader), loss=losses)
            logger.info(status)

    return losses.avg


def valid(valid_loader, model, logger):
    model.eval()  # eval mode (dropout and batchnorm is NOT used)

    losses = AverageMeter()

    # Batches
    for img, alpha_label in valid_loader:
        # Move to GPU, if available
        img = img.type(torch.FloatTensor).to(device)  # [N, 3, 320, 320]
        alpha_label = alpha_label.type(torch.FloatTensor).to(device)  # [N, 320, 320]
        alpha_label = alpha_label.reshape((-1, 2, im_size * im_size))  # [N, 320*320]

        # Forward prop.
        alpha_out = model(img)  # [N, 320, 320]
        alpha_out = alpha_out.reshape((-1, 1, im_size * im_size))  # [N, 320*320]

        # Calculate loss
        # loss = criterion(alpha_out, alpha_label)
        loss = alpha_prediction_loss(alpha_out, alpha_label)

        # Keep track of metrics
        losses.update(loss.item())

    # Print status
    status = 'Validation: Loss {loss.avg:.4f}\n'.format(loss=losses)

    logger.info(status)

    return losses.avg

抛开多处细节不谈,这两块代码的思路几乎完全一致,模型走一圈之后调用损失函数输出相应的损失值,不同的是train函数反向传播,valid就仅仅更新了损失就结束了。所以这两个函数的区别必须要进行深究。

首先说一下train和valid在整个深度学习的功用。首先用黄海广老师的资料图片来说明

 看到这里大概其对train函数和valid函数有个了解了,对应的就是训练集和验证集的训练过程。在写这个文章的时候由于事先并没有注意到这点所以对这两个函数的目的表示不知所以了一段时间。现在继续往下看。

先来看train函数的代码实现。

def train(train_loader, model, optimizer, epoch, logger):
    model.train()  # train mode (dropout and batchnorm is used)

    losses = AverageMeter()

    # Batches
    for i, (img, alpha_label) in enumerate(train_loader):
        # Move to GPU, if available
        img = img.type(torch.FloatTensor).to(device)  # [N, 4, 320, 320]
        alpha_label = alpha_label.type(torch.FloatTensor).to(device)  # [N, 320, 320]
        alpha_label = alpha_label.reshape((-1, 2, im_size * im_size))  # [N, 320*320]

        # Forward prop.
        alpha_out = model(img)  # [N, 3, 320, 320]
        alpha_out = alpha_out.reshape((-1, 1, im_size * im_size))  # [N, 320*320]

        # Calculate loss
        # loss = criterion(alpha_out, alpha_label)
        loss = alpha_prediction_loss(alpha_out, alpha_label)

        # Back prop.
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        clip_gradient(optimizer, grad_clip)

        # Update weights
        optimizer.step()

        # Keep track of metrics
        losses.update(loss.item())

        # Print status

        if i % print_freq == 0:
            status = 'Epoch: [{0}][{1}/{2}]\t' \
                     'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i, len(train_loader), loss=losses)
            logger.info(status)

    return losses.avg

整个train函数的运行逻辑特别明显的普遍的模型训练逻辑。这里面我的经验就是先不要在乎那么多细节,就把他想象成把大象装冰箱一样,剩下的细节在一点一点的抠。

 第一条model.train(),这是必须要写入的一个语句,在训练函数加上这个就说明一件事:这个训练函数里将会启用batch normalization和drop out。

valid函数的结构和train函数大体上不能说是基本一致也差不多一模一样,所以这里面就不多提。与此同时在valid函数的第一句用上了model.eval(),也就是说在valid函数里面训练的时候会使用batch normalization 但不会drop out。

valid函数结构如下,对比着train函数看看。

后面再看看两个函数里面的for循环的细节,先看train函数的。

    for i, (img, alpha_label) in enumerate(train_loader):
        # Move to GPU, if available
        img = img.type(torch.FloatTensor).to(device)  # [N, 4, 320, 320]
        alpha_label = alpha_label.type(torch.FloatTensor).to(device)  # [N, 320, 320]
        alpha_label = alpha_label.reshape((-1, 2, im_size * im_size))  # [N, 320*320]

        # Forward prop.
        alpha_out = model(img)  # [N, 3, 320, 320]
        alpha_out = alpha_out.reshape((-1, 1, im_size * im_size))  # [N, 320*320]

        # Calculate loss
        # loss = criterion(alpha_out, alpha_label)
        loss = alpha_prediction_loss(alpha_out, alpha_label)

        # Back prop.
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        clip_gradient(optimizer, grad_clip)

        # Update weights
        optimizer.step()

        # Keep track of metrics
        losses.update(loss.item())

        # Print status

        if i % print_freq == 0:
            status = 'Epoch: [{0}][{1}/{2}]\t' \
                     'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i, len(train_loader), loss=losses)
            logger.info(status)

for循环可以这么说就是为了单独在dataloader里面进行循环计算损失函数的数值。第一步:先把dataloader里面的img,alpha_label进行处理,第二步:把img放进模型里走一圈,得到模型出来的结果alpha_out,第三步:将alpha_loss和alpha_out进行损失函数的计算,第四步:反向传播更新权重以及损失值。一般第四步就是标准模板的三行代码:

  1. optimizer.zero_grad() 清空过往梯度;

  2. loss.backward() 反向传播,计算当前梯度;

  3. optimizer.step() 根据梯度更新网络参数

这里面就先不提train_loader和valid_loader,这两个函数具体如何运作需要很长的篇幅去搞,所以直接看for循环内部的运作(也就暂且忽略为什么reshape定下的参数是那些)。下一步重点要知道损失函数是怎么得出来的。要知道对于一个模型算法,第一关注是模型的结构,第二是对数据集的要求以及处理,第三就是损失函数的建立,所以对于损失函数是怎么构建的尤其重要。这里面train函数和valid函数对于求出损失值使用的都是一个函数:alpha_prediction_loss(),直接看这个函数的运作

def alpha_prediction_loss(y_pred, y_true):
    mask = y_true[:, 1, :]
    diff = y_pred[:, 0, :] - y_true[:, 0, :]
    diff = diff * mask
    num_pixels = torch.sum(mask)
    return torch.sum(torch.sqrt(torch.pow(diff, 2) + epsilon_sqr)) / (num_pixels + epsilon)

对比前面for循环里的用法就能得出:y_pred是模型得出来的值,y_true是数据集本身的值,也就是模型得出来的结果和本身数据集的正确答案互相计算得出来差异的数值特征也就是损失值。那么这里面是怎么计算的需要回到论文里面找找。

 论文里对于alpha_prediction_loss的表述就是这个公式,而且论文实际上提到了两个损失函数,另外一个是compositional loss,基本上长得跟他完全一致

 到这里就瞬间明白一件事:得,老美又偷懒了,就写了一个损失函数代码把两个都搞出来。最后两个损失函数以一个加权平均得出来总体的损失函数:Loverall = wl ·Lα+(1wl)·Lc

回到代码里面,坏菜了,看不懂了。。。。由于数据集的具体实现还不太清楚,按照代码的推算,diff还好说应该是两个数值的差别,mask到底是做什么的就不知道了。这里先用opencv的概念来解释一下,mask在opencv里面就是为了抠图的蒙版值相乘,那么这个东西很可能就是把最后的抠图搞出来的操作。是不是具体真的是这样,就得把数据集深挖出来搞定了。

最后总结一下,train和valid函数都是针对数据集的训练代码,不同的是对应的数据集训练目的不同,也就是说其实这两个运行逻辑区别真的不大,下一步就要死扣数据集的具体实现了。

  • 6
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值