从零实现无监督光流pipeline(2):训练代码,损失函数代码

前言

我们这里直接使用PWCNet了,先了解一下基线和接口啥的,后面就可以换成自己可以创新的模型了。我们借助的ARFlow的工程,里面真的很贴心呀,没有使用AR的和使用AR的两个版本,我们这里自然就是使用前者嘛!因为其实我们要进行方法创新的时候也不能直接套用它的方法呀。

模型部分

这里其实最主要的就是输入一个图像序列,然后返回一个光流金字塔字典,我们这里直接copy的ARFlow的代码,详细的去看下工程,这里就不附上了,因为这个地方的变化是最大的。

训练部分:base

ARFlow的训练代码写得好清晰呀,和数据dataset一样,我们这里先创建一个base的类(可复用),老规矩直接附上代码注释。

import torch
import numpy as np
from abc import abstractmethod
from tensorboardX import SummaryWriter
from utils.torch_utils import bias_parameters, weight_parameters, \
    load_checkpoint, save_checkpoint, AdamW


class BaseTrainer:
    """Base class for all trainers.

    Owns the pieces every trainer needs: data loaders, model (wrapped in
    DataParallel), optimizer, loss function, logger and a TensorBoard writer.
    Subclasses implement ``_run_one_epoch`` and ``_validate_with_gt``.
    """

    def __init__(self, train_loader, valid_loader, model, loss_func,
                 _log, save_root, config):
        self._log = _log
        # Experiment configuration object.
        self.cfg = config
        # Directory where checkpoints / logs are written.
        self.save_root = save_root
        # tensorboardX writer for scalar logging.
        self.summary_writer = SummaryWriter(str(save_root))
        # Training / validation data loaders.
        self.train_loader, self.valid_loader = train_loader, valid_loader
        # Resolve compute device and usable GPU ids from config['n_gpu'].
        self.device, self.device_ids = self._prepare_device(config['n_gpu'])
        # Model moved to device, optionally loaded with pretrained weights,
        # and wrapped in DataParallel (see _init_model).
        self.model = self._init_model(model)
        # Optimizer built over bias/weight parameter groups.
        self.optimizer = self._create_optimizer()
        # Loss callable used by subclasses during training.
        self.loss_func = loss_func

        # Best (lowest) validation error seen so far; used by save_model to
        # decide whether a checkpoint is the new best. NOT an early-stop
        # threshold.
        self.best_error = np.inf

        # Epoch / iteration counters advanced by the subclass training loop.
        self.i_epoch = 0
        self.i_iter = 0

    @abstractmethod
    def _run_one_epoch(self):
        # Subclass hook: train for one epoch.
        ...

    @abstractmethod
    def _validate_with_gt(self):
        # Subclass hook: evaluate against ground truth, return
        # (errors, error_names).
        ...

    def train(self):
        """Run the full training loop: epochs interleaved with validation."""
        for epoch in range(self.cfg.epoch_num):
            # Train for one epoch (subclass-defined).
            self._run_one_epoch()
            # Validate every cfg.val_epoch_size epochs.
            if self.i_epoch % self.cfg.val_epoch_size == 0:
                # Evaluate against ground-truth labels.
                errors, error_names = self._validate_with_gt()
                # Format "name: value" pairs for logging.
                valid_res = ' '.join(
                    '{}: {:.2f}'.format(*t) for t in zip(error_names, errors))
                # Log the validation summary.
                self._log.info(' * Epoch {} '.format(self.i_epoch) + valid_res)

    def _init_model(self, model):
        """Move the model to device, load pretrained weights (or initialize
        from scratch) and wrap it in DataParallel."""
        model = model.to(self.device)
        # Load pretrained weights when a checkpoint path is configured.
        if self.cfg.pretrained_model:
            self._log.info("=> using pre-trained weights {}.".format(
                self.cfg.pretrained_model))
            epoch, weights = load_checkpoint(self.cfg.pretrained_model)

            from collections import OrderedDict
            new_weights = OrderedDict()
            # Current model parameter names.
            model_keys = list(model.state_dict().keys())
            # Checkpoint parameter names.
            weight_keys = list(weights.keys())
            # NOTE(review): weights are matched to parameters purely by
            # POSITION (zip of the two key lists), not by name. This only
            # works when both state dicts enumerate parameters in the same
            # order — fragile if the architecture differs from the checkpoint.
            for a, b in zip(model_keys, weight_keys):
                new_weights[a] = weights[b]
            weights = new_weights
            model.load_state_dict(weights)
        # Otherwise train from scratch with the model's own initializer.
        else:
            self._log.info("=> Train from scratch.")
            model.init_weights()
        # Wrap for (possible) multi-GPU training; access the raw model via
        # model.module afterwards.
        model = torch.nn.DataParallel(model, device_ids=self.device_ids)
        return model

    def _create_optimizer(self):
        """Build the optimizer with separate weight-decay settings for bias
        and weight parameter groups."""
        self._log.info('=> setting Adam solver')
        param_groups = [
            {'params': bias_parameters(self.model.module),
             'weight_decay': self.cfg.bias_decay},
            {'params': weight_parameters(self.model.module),
             'weight_decay': self.cfg.weight_decay}]

        if self.cfg.optim == 'adamw':
            optimizer = AdamW(param_groups, self.cfg.lr,
                              betas=(self.cfg.momentum, self.cfg.beta))
        elif self.cfg.optim == 'adam':
            optimizer = torch.optim.Adam(param_groups, self.cfg.lr,
                                         betas=(self.cfg.momentum, self.cfg.beta),
                                         eps=1e-7)
        else:
            raise NotImplementedError(self.cfg.optim)
        return optimizer

    def _prepare_device(self, n_gpu_use):
        """
        setup GPU device if available, move model into configured device
        """
        n_gpu = torch.cuda.device_count()
        # Requested GPUs but none available: fall back to CPU.
        if n_gpu_use > 0 and n_gpu == 0:
            self._log.warning("Warning: There\'s no GPU available on this machine,"
                              "training will be performed on CPU.")
            n_gpu_use = 0
        # Requested more GPUs than exist: clamp to what is available.
        if n_gpu_use > n_gpu:
            self._log.warning(
                "Warning: The number of GPU\'s configured to use is {}, "
                "but only {} are available.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids
    
    def save_model(self, error, name):
        """Save a checkpoint; additionally mark it as best when `error`
        improves on the best error seen so far."""
        is_best = error < self.best_error

        if is_best:
            self.best_error = error

        models = {'epoch': self.i_epoch,
                  'state_dict': self.model.module.state_dict()}

        save_checkpoint(self.save_root, models, name, is_best)

有些代码是在ARFlow工程中的utils中。我将上面要用到的代码直接复制如下:

# torch_utils
import torch
import shutil
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import numbers
import random
import math
from torch.optim import Optimizer


def init_seed(seed):
    """Seed every RNG source (torch CPU, torch CUDA, numpy, python random)
    with the same value for reproducible runs."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all,
                   np.random.seed, random.seed):
        seeder(seed)


def weight_parameters(module):
    """Return all parameters of `module` whose name contains 'weight'."""
    selected = []
    for name, param in module.named_parameters():
        if 'weight' in name:
            selected.append(param)
    return selected


def bias_parameters(module):
    """Return all parameters of `module` whose name contains 'bias'."""
    return [p for n, p in module.named_parameters() if n.find('bias') >= 0]


def load_checkpoint(model_path):
    """Load a checkpoint file and return ``(epoch, state_dict)``.

    Accepts either a plain state dict or a wrapper dict containing
    ``'epoch'`` and/or ``'state_dict'`` keys (the format written by
    ``save_checkpoint``).

    Args:
        model_path: path to a file produced by ``torch.save``.

    Returns:
        Tuple of (epoch or None, state_dict).
    """
    # map_location='cpu' lets checkpoints saved on a GPU machine be loaded
    # on CPU-only hosts; tensors are moved to the right device later when
    # load_state_dict copies them into the (already placed) model.
    weights = torch.load(model_path, map_location='cpu')
    epoch = None
    if 'epoch' in weights:
        epoch = weights.pop('epoch')
    # Unwrap the 'state_dict' key if present; otherwise the file itself is
    # the state dict.
    if 'state_dict' in weights:
        state_dict = weights['state_dict']
    else:
        state_dict = weights
    return epoch, state_dict


def save_checkpoint(save_path, states, file_prefixes, is_best, filename='ckpt.pth.tar'):
    """Save one or several checkpoints under `save_path`.

    `file_prefixes` may be a single prefix (with `states` a single state
    object) or a sequence of prefixes paired element-wise with `states`.
    When `is_best` is true, each checkpoint is also copied to
    ``<prefix>_model_best.pth.tar``.
    """
    def _save_one(state, prefix):
        target = save_path / '{}_{}'.format(prefix, filename)
        torch.save(state, target)
        if is_best:
            shutil.copyfile(target,
                            save_path / '{}_model_best.pth.tar'.format(prefix))

    if isinstance(file_prefixes, str):
        _save_one(states, file_prefixes)
    else:
        for prefix, state in zip(file_prefixes, states):
            _save_one(state, prefix)


def restore_model(model, pretrained_file):
    """Load pretrained weights into `model`, reconciling mismatched keys.

    Keys present in the model but missing from the checkpoint keep the
    model's current values; checkpoint keys absent from the model are
    dropped. Both cases print a warning listing the affected keys.
    """
    epoch, weights = load_checkpoint(pretrained_file)

    model_keys = set(model.state_dict().keys())
    weight_keys = set(weights.keys())

    # Keys only on one side, in deterministic order for readable warnings.
    missing_in_checkpoint = sorted(model_keys - weight_keys)
    extra_in_checkpoint = sorted(weight_keys - model_keys)

    if missing_in_checkpoint:
        print('Warning: There are weights in model but not in pre-trained.')
        for key in missing_in_checkpoint:
            print(key)
            # Fall back to the model's own (current) tensor for this key.
            weights[key] = model.state_dict()[key]

    if extra_in_checkpoint:
        print('Warning: There are pre-trained weights not in model.')
        for key in extra_in_checkpoint:
            print(key)
        from collections import OrderedDict
        # Keep only the keys the model actually has.
        weights = OrderedDict((k, weights[k]) for k in model_keys)

    model.load_state_dict(weights)
    return model


class AdamW(Optimizer):
    """Implements AdamW algorithm.

    It has been proposed in `Fixing Weight Decay Regularization in Adam`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)

    .. Fixing Weight Decay Regularization in Adam:
    https://arxiv.org/abs/1711.05101
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super(AdamW, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdamW does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average
                # coefficient. Keyword forms (alpha=/value=) are used: the
                # old positional Number-first overloads of add_/addcmul_/
                # addcdiv_ were deprecated and then removed from PyTorch,
                # so the original code raises TypeError on modern releases.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                # p <- p - step_size * exp_avg / denom
                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Decoupled weight decay, applied after the Adam update as
                # in the AdamW paper. NOTE(review): unlike the paper, the
                # decay here is NOT scaled by lr (p *= 1 - wd); preserved
                # unchanged for parity with the original implementation.
                if group['weight_decay'] != 0:
                    p.data.mul_(1 - group['weight_decay'])

        return loss

我们这里得去学习一下光流预测,是写在了一个函数中:

def evaluate_flow(gt_flows, pred_flows, moving_masks=None):
    # credit "undepthflow/eval/evaluate_flow.py"
    """Compute end-point-error (EPE) metrics for a batch of predicted flows.

    Each gt_flow is HxWx2 (u, v) or HxWx4 — the code below reads channel 2
    as a validity/occlusion mask and channel 3 as the non-occluded mask
    (KITTI-style; assumption inferred from usage — confirm against the
    loader). Predictions are rescaled and resized to the GT resolution
    before comparison. Returns [EPE] for 2-channel GT, or
    [EPE, EPE_noc, EPE_occ, F1_rate] (plus moving/static EPE when
    `moving_masks` is given) for 4-channel GT, each averaged over the batch.
    """
    def calculate_error_rate(epe_map, gt_flow, mask):
        # KITTI-style outlier ("bad pixel") rate: a pixel is bad when its
        # EPE exceeds 3 px AND exceeds 5% of the GT flow magnitude,
        # restricted to `mask`. Returned as a percentage of masked pixels.
        bad_pixels = np.logical_and(
            # Absolute criterion: EPE > 3 px (mask zeroes out other pixels).
            epe_map * mask > 3,
            # Relative criterion: EPE / ||gt_flow|| > 5%.
            epe_map * mask / np.maximum(
                # L2 norm of the GT flow, clamped at 1e-10 to avoid div-by-0.
                np.sqrt(np.sum(np.square(gt_flow), axis=2)), 1e-10) > 0.05)
        # Percentage of bad pixels among the masked pixels.
        return bad_pixels.sum() / mask.sum() * 100.

    # Accumulators: overall, non-occluded, occluded, moving, static EPE and
    # the overall outlier rate.
    error, error_noc, error_occ, error_move, error_static, error_rate = \
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0
    # Outlier rates restricted to moving / static regions.
    error_move_rate, error_static_rate = 0.0, 0.0
    # Batch size.
    B = len(gt_flows)
    for gt_flow, pred_flow, i in zip(gt_flows, pred_flows, range(B)):
        # GT resolution (H, W) and prediction resolution (h, w).
        H, W = gt_flow.shape[:2]
        h, w = pred_flow.shape[:2]
        # Copy so the rescaling below does not mutate the caller's array.
        pred_flow = np.copy(pred_flow)
        # Rescale flow vectors from prediction resolution to GT resolution
        # so displacements stay consistent after the spatial resize.
        pred_flow[:, :, 0] = pred_flow[:, :, 0] / w * W
        pred_flow[:, :, 1] = pred_flow[:, :, 1] / h * H
        # Bilinearly resize the flow field itself to the GT resolution.
        flo_pred = cv2.resize(pred_flow, (W, H), interpolation=cv2.INTER_LINEAR)

        # Per-pixel end-point error: L2 distance between predicted and GT
        # flow vectors (sum over the channel axis, then sqrt).
        epe_map = np.sqrt(
            np.sum(np.square(flo_pred[:, :, :2] - gt_flow[:, :, :2]),
                   axis=2))

        if gt_flow.shape[-1] == 2:
            # No masks available: plain mean EPE over all pixels.
            error += np.mean(epe_map)

        elif gt_flow.shape[-1] == 4:
            # Mean EPE over pixels marked valid by channel 2.
            error += np.sum(epe_map * gt_flow[:, :, 2]) / np.sum(gt_flow[:, :, 2])
            # Channel 3: non-occluded pixel mask.
            noc_mask = gt_flow[:, :, -1]
            # EPE over non-occluded pixels.
            error_noc += np.sum(epe_map * noc_mask) / np.sum(noc_mask)
            # EPE over occluded pixels (valid minus non-occluded); the
            # denominator is clamped at 1 in case no pixel is occluded.
            error_occ += np.sum(epe_map * (gt_flow[:, :, 2] - noc_mask)) / max(
                np.sum(gt_flow[:, :, 2] - noc_mask), 1.0)
            # Outlier rate over all valid pixels.
            error_rate += calculate_error_rate(epe_map, gt_flow[:, :, 0:2],
                                               gt_flow[:, :, 2])
            # Optional split into moving vs static regions.
            if moving_masks is not None:
                move_mask = moving_masks[i]

                error_move_rate += calculate_error_rate(
                    epe_map, gt_flow[:, :, 0:2], gt_flow[:, :, 2] * move_mask)
                error_static_rate += calculate_error_rate(
                    epe_map, gt_flow[:, :, 0:2],
                    gt_flow[:, :, 2] * (1.0 - move_mask))

                error_move += np.sum(epe_map * gt_flow[:, :, 2] *
                                     move_mask) / np.sum(gt_flow[:, :, 2] *
                                                         move_mask)
                error_static += np.sum(epe_map * gt_flow[:, :, 2] * (
                        1.0 - move_mask)) / np.sum(gt_flow[:, :, 2] *
                                                   (1.0 - move_mask))

    # Return the batch-averaged metrics matching the GT channel layout.
    if gt_flows[0].shape[-1] == 4:
        res = [error / B, error_noc / B, error_occ / B, error_rate / B]
        if moving_masks is not None:
            res += [error_move / B, error_static / B]
        return res
    else:
        return [error / B]

这之后我们还得看一下计算和存储平均值的类 AverageMeter,用于跟踪一个或多个指标的平均值和当前值。

import collections

def update_dict(orig_dict, new_dict):
    """Recursively merge `new_dict` into `orig_dict` (in place) and return it.

    Nested mappings are merged key-by-key instead of being replaced
    wholesale; non-mapping values overwrite the original entry.
    """
    # `collections.Mapping` was removed in Python 3.10; the abstract base
    # classes live in collections.abc.
    from collections.abc import Mapping

    for key, val in new_dict.items():
        if isinstance(val, Mapping):
            orig_dict[key] = update_dict(orig_dict.get(key, {}), val)
        else:
            orig_dict[key] = val
    return orig_dict


class AverageMeter(object):
    """Track current value, running sum, count and mean for one or more metrics."""

    def __init__(self, i=1, precision=3, names=None):
        # Number of metrics tracked in parallel.
        self.meters = i
        # Decimal places used by __repr__.
        self.precision = precision
        # Start with all statistics zeroed.
        self.reset(self.meters)
        # Optional display names, one per metric.
        self.names = names
        if names is None:
            self.names = [''] * self.meters
        else:
            assert self.meters == len(self.names)

    def reset(self, i):
        """Zero out all statistics for `i` metrics."""
        self.val, self.avg, self.sum, self.count = ([0] * i for _ in range(4))

    def update(self, val, n=1):
        """Record new value(s) `val`, each weighted by the matching count in `n`."""
        vals = val if isinstance(val, list) else [val]
        counts = n if isinstance(n, list) else [n] * self.meters
        assert len(vals) == self.meters and len(counts) == self.meters
        for idx, (v, cnt) in enumerate(zip(vals, counts)):
            self.count[idx] += cnt
            self.val[idx] = v
            self.sum[idx] += v * cnt
            self.avg[idx] = self.sum[idx] / self.count[idx]

    def __repr__(self):
        """Render as '<current values> (<running averages>)'."""
        current = ' '.join('{} {:.{}f}'.format(name, v, self.precision)
                           for name, v in zip(self.names, self.val))
        average = ' '.join('{} {:.{}f}'.format(name, a, self.precision)
                           for name, a in zip(self.names, self.avg))
        return '{} ({})'.format(current, average)

在Python中,__repr__ 是一个特殊方法,用于定义对象的"官方"字符串表示形式。需要注意的是,当你使用 print(obj) 或 str(obj) 时,Python 实际调用的是 __str__ 方法;只有当类没有定义 __str__ 时,才会回退到 __repr__。而直接在交互式环境中显示对象或调用 repr(obj) 时,则总是使用 __repr__。

sintel_trainer

我们先创建一个类

import time
import torch
from .base_trainer import BaseTrainer
from utils.flow_utils import evaluate_flow
from utils.misc_utils import AverageMeter


class TrainFramework(BaseTrainer):
    """Concrete trainer; all shared setup is delegated to BaseTrainer."""

    def __init__(self, train_loader, valid_loader, model, loss_func,
                 _log, save_root, config):
        # Python 3 zero-argument super() — same MRO dispatch as the
        # explicit super(TrainFramework, self) form.
        super().__init__(train_loader, valid_loader, model, loss_func,
                         _log, save_root, config)

之后在base中我们写了循环epoch的代码,那么这里我们就要去写一次epoch怎么训练就好了,这个其实也不是固定的,根据不同的任务,要去书写不同的代码,比如ARFlow中的两次前向传播的过程,就和我们下面的写的基础训练不一样的,如果我们后面要进行创新,这里面也是要自己去书写的。

    def _run_one_epoch(self):
        """Train the model for one epoch of unsupervised flow learning."""
        # Meters for per-batch wall time and data-loading time.
        am_batch_time = AverageMeter()
        am_data_time = AverageMeter()
        # Tracked scalars: total loss, photometric loss, smoothness loss,
        # mean flow magnitude.
        key_meter_names = ['Loss', 'l_ph', 'l_sm', 'flow_mean']
        key_meters = AverageMeter(i=len(key_meter_names), precision=4)
        # Switch to training mode (affects dropout/batchnorm etc.).
        self.model.train()
        # Timestamp for timing measurements.
        end = time.time()
        # When the configured 'stage1' epoch is reached, swap in the stage-1
        # loss configuration — lets different training phases use different
        # loss settings.
        if 'stage1' in self.cfg:
            if self.i_epoch == self.cfg.stage1.epoch:
                self.loss_func.cfg.update(self.cfg.stage1.loss)

        for i_step, data in enumerate(self.train_loader):
            # Cap the number of steps per epoch at cfg.epoch_size.
            if i_step > self.cfg.epoch_size:
                break
            # read data to device; the image pair is concatenated along the
            # channel dimension.
            img1, img2 = data['img1'], data['img2']
            img_pair = torch.cat([img1, img2], 1).to(self.device)

            # measure data loading time
            am_data_time.update(time.time() - end)

            # compute output: with_bk=True makes the model also return
            # backward flows ('flows_bw'), needed for the bidirectional loss.
            res_dict = self.model(img_pair, with_bk=True)
            flows_12, flows_21 = res_dict['flows_fw'], res_dict['flows_bw']
            # Pair forward/backward flows per pyramid level for the loss.
            flows = [torch.cat([flo12, flo21], 1) for flo12, flo21 in
                     zip(flows_12, flows_21)]
            # Unsupervised loss: total, photometric, smoothness, mean flow.
            loss, l_ph, l_sm, flow_mean = self.loss_func(flows, img_pair)

            # update meters (weighted by batch size)
            key_meters.update([loss.item(), l_ph.item(), l_sm.item(), flow_mean.item()],
                              img_pair.size(0))

            # compute gradient and do optimization step
            self.optimizer.zero_grad()
            # loss.backward()

            # Scale the loss by 1024 before backward to keep small gradients
            # from underflowing (ARFlow convention); the gradients are
            # rescaled back below, so the effective update is unchanged.
            scaled_loss = 1024. * loss
            scaled_loss.backward()

            for param in [p for p in self.model.parameters() if p.requires_grad]:
                # Undo the 1024x loss scaling on the gradients.
                param.grad.data.mul_(1. / 1024)
            # Apply the parameter update.
            self.optimizer.step()

            # measure elapsed time
            am_batch_time.update(time.time() - end)
            end = time.time()

            # Periodically log scalars to TensorBoard...
            if self.i_iter % self.cfg.record_freq == 0:
                for v, name in zip(key_meters.val, key_meter_names):
                    self.summary_writer.add_scalar('Train_' + name, v, self.i_iter)
            # ...and print progress to the logger.
            if self.i_iter % self.cfg.print_freq == 0:
                istr = '{}:{:04d}/{:04d}'.format(
                    self.i_epoch, i_step, self.cfg.epoch_size) + \
                       ' Time {} Data {}'.format(am_batch_time, am_data_time) + \
                       ' Info {}'.format(key_meters)
                self._log.info(istr)

            self.i_iter += 1
        self.i_epoch += 1

之后我们就是要去看一下验证的代码,我们在前面写了一个验证的函数,下面我们会调用它的。

import time
import torch
import numpy as np
from .base_trainer import BaseTrainer
from utils.flow_utils import load_flow, evaluate_flow
from utils.misc_utils import AverageMeter


class TrainFramework(BaseTrainer):
    """Trainer with the unsupervised training epoch and a ground-truth
    validation pass over one or more validation loaders."""

    def __init__(self, train_loader, valid_loader, model, loss_func,
                 _log, save_root, config):
        super(TrainFramework, self).__init__(
            train_loader, valid_loader, model, loss_func, _log, save_root, config)

    def _run_one_epoch(self):
        """Train the model for one epoch of unsupervised flow learning."""
        # Meters for per-batch wall time and data-loading time.
        am_batch_time = AverageMeter()
        am_data_time = AverageMeter()

        # Tracked scalars: total loss, photometric loss, smoothness loss,
        # mean flow magnitude.
        key_meter_names = ['Loss', 'l_ph', 'l_sm', 'flow_mean']
        key_meters = AverageMeter(i=len(key_meter_names), precision=4)

        self.model.train()
        end = time.time()

        # Swap in the stage-1 loss configuration when its epoch is reached.
        if 'stage1' in self.cfg:
            if self.i_epoch == self.cfg.stage1.epoch:
                self.loss_func.cfg.update(self.cfg.stage1.loss)

        for i_step, data in enumerate(self.train_loader):
            # Cap the number of steps per epoch at cfg.epoch_size.
            if i_step > self.cfg.epoch_size:
                break
            # read data to device (image pair concatenated channel-wise)
            img1, img2 = data['img1'], data['img2']
            img_pair = torch.cat([img1, img2], 1).to(self.device)

            # measure data loading time
            am_data_time.update(time.time() - end)

            # compute output; with_bk=True also returns backward flows.
            res_dict = self.model(img_pair, with_bk=True)
            flows_12, flows_21 = res_dict['flows_fw'], res_dict['flows_bw']
            # Pair forward/backward flows per pyramid level for the loss.
            flows = [torch.cat([flo12, flo21], 1) for flo12, flo21 in
                     zip(flows_12, flows_21)]
            loss, l_ph, l_sm, flow_mean = self.loss_func(flows, img_pair)

            # update meters (weighted by batch size)
            key_meters.update([loss.item(), l_ph.item(), l_sm.item(), flow_mean.item()],
                              img_pair.size(0))

            # compute gradient and do optimization step
            self.optimizer.zero_grad()
            # loss.backward()

            # Scale the loss by 1024 before backward to keep small gradients
            # from underflowing; gradients are rescaled back below.
            scaled_loss = 1024. * loss
            scaled_loss.backward()

            for param in [p for p in self.model.parameters() if p.requires_grad]:
                # Undo the 1024x loss scaling on the gradients.
                param.grad.data.mul_(1. / 1024)

            self.optimizer.step()

            # measure elapsed time
            am_batch_time.update(time.time() - end)
            end = time.time()

            # Periodic TensorBoard scalar logging.
            if self.i_iter % self.cfg.record_freq == 0:
                for v, name in zip(key_meters.val, key_meter_names):
                    self.summary_writer.add_scalar('Train_' + name, v, self.i_iter)

            # Periodic progress printing.
            if self.i_iter % self.cfg.print_freq == 0:
                istr = '{}:{:04d}/{:04d}'.format(
                    self.i_epoch, i_step, self.cfg.epoch_size) + \
                       ' Time {} Data {}'.format(am_batch_time, am_data_time) + \
                       ' Info {}'.format(key_meters)
                self._log.info(istr)

            self.i_iter += 1
        self.i_epoch += 1

    @torch.no_grad()
    def _validate_with_gt(self):
        """Evaluate on the validation loader(s) against ground-truth flow.

        Returns (all_error_avgs, all_error_names), one entry per metric per
        validation set. Also logs to TensorBoard and may save a checkpoint.
        """
        batch_time = AverageMeter()

        # Normalize to a list so multiple validation sets are handled
        # uniformly (this permanently rewraps self.valid_loader).
        if type(self.valid_loader) is not list:
            self.valid_loader = [self.valid_loader]

        # only use the first GPU to run validation, multiple GPUs might raise error.
        # https://github.com/Eromera/erfnet_pytorch/issues/2#issuecomment-486142360
        # Temporarily unwrap DataParallel; it is re-wrapped at the end.
        self.model = self.model.module
        self.model.eval()

        end = time.time()

        all_error_names = []
        all_error_avgs = []

        n_step = 0
        for i_set, loader in enumerate(self.valid_loader):
            # KITTI-style metric names: EPE, non-occluded / occluded EPE,
            # and the F1 outlier rate.
            error_names = ['EPE', 'E_noc', 'E_occ', 'F1_all']
            error_meters = AverageMeter(i=len(error_names))
            for i_step, data in enumerate(loader):
                img1, img2 = data['img1'], data['img2']
                img_pair = torch.cat([img1, img2], 1).to(self.device)

                # Load GT flow + occlusion masks from the 'flow_occ' files,
                # and non-occluded masks from the 'flow_noc' files.
                res = list(map(load_flow, data['flow_occ']))
                gt_flows, occ_masks = [r[0] for r in res], [r[1] for r in res]
                res = list(map(load_flow, data['flow_noc']))
                _, noc_masks = [r[0] for r in res], [r[1] for r in res]

                # Stack flow (2ch) + occ mask + noc mask into the HxWx4
                # layout expected by evaluate_flow.
                gt_flows = [np.concatenate([flow, occ_mask, noc_mask], axis=2) for
                            flow, occ_mask, noc_mask in
                            zip(gt_flows, occ_masks, noc_masks)]

                # compute output: highest-resolution forward flow, converted
                # to NHWC numpy for evaluation.
                flows = self.model(img_pair)['flows_fw']
                pred_flows = flows[0].detach().cpu().numpy().transpose([0, 2, 3, 1])

                es = evaluate_flow(gt_flows, pred_flows)
                error_meters.update([l.item() for l in es], img_pair.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i_step % self.cfg.print_freq == 0 or i_step == len(loader) - 1:
                    self._log.info('Test: {0}[{1}/{2}]\t Time {3}\t '.format(
                        i_set, i_step, self.cfg.valid_size, batch_time) + ' '.join(
                        map('{:.2f}'.format, error_meters.avg)))

                # Cap the number of validation steps per set.
                if i_step > self.cfg.valid_size:
                    break
            n_step += len(loader)

            # write error to tf board.
            for value, name in zip(error_meters.avg, error_names):
                self.summary_writer.add_scalar(
                    'Valid_{}_{}'.format(name, i_set), value, self.i_epoch)

            all_error_avgs.extend(error_meters.avg)
            all_error_names.extend(['{}_{}'.format(name, i_set) for name in error_names])

        # Re-wrap the model for multi-GPU training after validation.
        self.model = torch.nn.DataParallel(self.model, device_ids=self.device_ids)
        # In order to reduce the space occupied during debugging,
        # only the model with more than cfg.save_iter iterations will be saved.
        if self.i_iter > self.cfg.save_iter:
            # Checkpoint keyed on the first metric (EPE of the first set).
            self.save_model(all_error_avgs[0], name='KITTI_Flow')

        return all_error_avgs, all_error_names

这里面验证集是一个列表,因为 Sintel 和 KITTI 的验证集都不止一个(例如 clean/final、2012/2015)。

kitti_trainer

运行一个epoch是一样的,这里仅更改验证

import time
import torch
import numpy as np
from .base_trainer import BaseTrainer
from utils.flow_utils import load_flow, evaluate_flow
from utils.misc_utils import AverageMeter


class TrainFramework(BaseTrainer):
    """KITTI trainer: reuses BaseTrainer and implements the unsupervised
    training epoch plus validation against KITTI ground-truth flow."""

    def __init__(self, train_loader, valid_loader, model, loss_func,
                 _log, save_root, config):
        super(TrainFramework, self).__init__(
            train_loader, valid_loader, model, loss_func, _log, save_root, config)

    def _run_one_epoch(self):
        """Run one unsupervised training epoch over ``self.train_loader``."""
        am_batch_time = AverageMeter()
        am_data_time = AverageMeter()

        # Scalar names tracked per iteration (total loss, photometric,
        # smoothness, mean |flow|).
        key_meter_names = ['Loss', 'l_ph', 'l_sm', 'flow_mean']
        key_meters = AverageMeter(i=len(key_meter_names), precision=4)

        self.model.train()
        end = time.time()

        # Two-stage training: once the configured epoch is reached, switch
        # the loss function to the stage-1 loss settings.
        if 'stage1' in self.cfg:
            if self.i_epoch == self.cfg.stage1.epoch:
                self.loss_func.cfg.update(self.cfg.stage1.loss)

        for i_step, data in enumerate(self.train_loader):
            if i_step > self.cfg.epoch_size:
                break
            # read data to device
            img1, img2 = data['img1'], data['img2']
            # Concatenate the pair along channels: B x 6 x H x W.
            img_pair = torch.cat([img1, img2], 1).to(self.device)

            # measure data loading time
            am_data_time.update(time.time() - end)

            # compute output: forward and backward flow pyramids
            res_dict = self.model(img_pair, with_bk=True)
            flows_12, flows_21 = res_dict['flows_fw'], res_dict['flows_bw']
            # Stack fw/bw flow per pyramid level: B x 4 x h x w.
            flows = [torch.cat([flo12, flo21], 1) for flo12, flo21 in
                     zip(flows_12, flows_21)]
            loss, l_ph, l_sm, flow_mean = self.loss_func(flows, img_pair)

            # update meters
            key_meters.update([loss.item(), l_ph.item(), l_sm.item(), flow_mean.item()],
                              img_pair.size(0))

            # compute gradient and do optimization step
            self.optimizer.zero_grad()
            # loss.backward()

            # Manual loss scaling (x1024) to reduce gradient underflow;
            # the gradients are unscaled below before the optimizer step,
            # so the effective update equals plain loss.backward().
            scaled_loss = 1024. * loss
            scaled_loss.backward()

            for param in [p for p in self.model.parameters() if p.requires_grad]:
                param.grad.data.mul_(1. / 1024)

            self.optimizer.step()

            # measure elapsed time
            am_batch_time.update(time.time() - end)
            end = time.time()

            if self.i_iter % self.cfg.record_freq == 0:
                for v, name in zip(key_meters.val, key_meter_names):
                    self.summary_writer.add_scalar('Train_' + name, v, self.i_iter)

            if self.i_iter % self.cfg.print_freq == 0:
                istr = '{}:{:04d}/{:04d}'.format(
                    self.i_epoch, i_step, self.cfg.epoch_size) + \
                       ' Time {} Data {}'.format(am_batch_time, am_data_time) + \
                       ' Info {}'.format(key_meters)
                self._log.info(istr)

            self.i_iter += 1
        self.i_epoch += 1

    @torch.no_grad()
    def _validate_with_gt(self):
        """Validate on the ground-truth loaders.

        Returns (all_error_avgs, all_error_names), one entry per metric per
        validation set.
        """
        batch_time = AverageMeter()

        # Normalize to a list of loaders (KITTI 2012 + 2015 use two sets).
        if type(self.valid_loader) is not list:
            self.valid_loader = [self.valid_loader]

        # only use the first GPU to run validation, multiple GPUs might raise error.
        # https://github.com/Eromera/erfnet_pytorch/issues/2#issuecomment-486142360
        self.model = self.model.module
        self.model.eval()

        end = time.time()

        all_error_names = []
        all_error_avgs = []

        n_step = 0
        for i_set, loader in enumerate(self.valid_loader):
            error_names = ['EPE', 'E_noc', 'E_occ', 'F1_all']
            error_meters = AverageMeter(i=len(error_names))
            for i_step, data in enumerate(loader):
                img1, img2 = data['img1'], data['img2']
                img_pair = torch.cat([img1, img2], 1).to(self.device)

                res = list(map(load_flow, data['flow_occ']))
                # Ground-truth flow and the "occ" (all annotated pixels) mask.
                gt_flows, occ_masks = [r[0] for r in res], [r[1] for r in res]
                res = list(map(load_flow, data['flow_noc']))
                # Only the non-occluded ("noc") mask is needed from these files.
                _, noc_masks = [r[0] for r in res], [r[1] for r in res]
                # Concatenate flow, occ mask and noc mask along the channel
                # axis so each ground-truth entry is H x W x 4.
                gt_flows = [np.concatenate([flow, occ_mask, noc_mask], axis=2) for
                            flow, occ_mask, noc_mask in
                            zip(gt_flows, occ_masks, noc_masks)]

                # compute output
                flows = self.model(img_pair)['flows_fw']
                pred_flows = flows[0].detach().cpu().numpy().transpose([0, 2, 3, 1])
                # evaluate_flow handles the 4-channel gt (flow + occ + noc) case
                es = evaluate_flow(gt_flows, pred_flows)
                error_meters.update([l.item() for l in es], img_pair.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i_step % self.cfg.print_freq == 0 or i_step == len(loader) - 1:
                    self._log.info('Test: {0}[{1}/{2}]\t Time {3}\t '.format(
                        i_set, i_step, self.cfg.valid_size, batch_time) + ' '.join(
                        map('{:.2f}'.format, error_meters.avg)))

                if i_step > self.cfg.valid_size:
                    break
            n_step += len(loader)

            # write error to tf board.
            for value, name in zip(error_meters.avg, error_names):
                self.summary_writer.add_scalar(
                    'Valid_{}_{}'.format(name, i_set), value, self.i_epoch)

            all_error_avgs.extend(error_meters.avg)
            all_error_names.extend(['{}_{}'.format(name, i_set) for name in error_names])

        # Re-wrap with DataParallel for the next training epoch.
        self.model = torch.nn.DataParallel(self.model, device_ids=self.device_ids)
        # In order to reduce the space occupied during debugging,
        # only the model with more than cfg.save_iter iterations will be saved.
        if self.i_iter > self.cfg.save_iter:
            self.save_model(all_error_avgs[0], name='KITTI_Flow')

        return all_error_avgs, all_error_names

flychairs_trainer

这个先略过

损失函数

对于无监督光流估计来说,最主要的就是亮度损失,但是为了更好的精度,增加了平滑度损失等,这里面我们将常用的直接copy过来就行,ARFlow也是使用的是unflow的代码。这里的模块是可以直接copy,但是后面换模型的话,就是没有使用到金字塔,那么这里的代码岂不是也得改吗?

这里注意一下:计算亮度损失不只是单纯地用图片1减去(图片2经过warp后的图片1'),而是计算两者的L1距离、SSIM或者Ternary损失,因为这些操作已经被证明更加好用。并且平滑损失也分一阶平滑和二阶平滑。

    # Photometric loss; args: scaled image1, image1 reconstructed by warping image2, occlusion mask
    def loss_photomatric(self, im1_scaled, im1_recons, occu_mask1):
        """Photometric loss between an image and its flow-warped reconstruction.

        Each enabled term (L1 / SSIM / ternary census, weighted by the
        corresponding ``cfg.w_*``) is averaged over all pixels, summed, and
        normalized by the mean of the occlusion mask so that heavily occluded
        frames are not under-penalized.
        """
        cfg = self.cfg
        terms = []
        # Plain L1 difference, masked to non-occluded pixels.
        if cfg.w_l1 > 0:
            terms.append(cfg.w_l1 * (im1_scaled - im1_recons).abs() * occu_mask1)
        # Structural-similarity term on the masked images.
        if cfg.w_ssim > 0:
            terms.append(cfg.w_ssim * SSIM(im1_recons * occu_mask1,
                                           im1_scaled * occu_mask1))
        # Ternary census-transform term on the masked images.
        if cfg.w_ternary > 0:
            terms.append(cfg.w_ternary * TernaryLoss(im1_recons * occu_mask1,
                                                     im1_scaled * occu_mask1))
        # Mean of each term, summed, then normalized by mask coverage.
        total = sum(t.mean() for t in terms)
        return total / occu_mask1.mean()
    # Smoothness loss
    def loss_smooth(self, flow, im1_scaled):
        """Edge-aware smoothness loss for one flow field.

        Uses second-order smoothness when ``cfg.smooth_2nd`` is set and
        truthy, first-order otherwise; ``cfg.alpha`` controls edge weighting.
        """
        use_second_order = 'smooth_2nd' in self.cfg and self.cfg.smooth_2nd
        func_smooth = smooth_grad_2nd if use_second_order else smooth_grad_1st
        return func_smooth(flow, im1_scaled, self.cfg.alpha).mean()

到这里是一样的,就是后面涉及的到金字塔的地方后面可能不一样

    # TODO: revisit this when replacing the pyramid-based model
    def forward(self, output, target):
        """Compute the multi-scale unsupervised loss.

        :param output: Multi-scale forward/backward flows n * [B x 4 x h x w]
        :param target: image pairs Nx6xHxW
        :return: (total_loss, warp_loss, smooth_loss, mean |flow| of the
            finest pyramid level)
        """
        # Flow pyramid (level 0 is the finest resolution).
        pyramid_flows = output
        # Split the channel-concatenated pair back into the two images.
        im1_origin = target[:, :3]
        im2_origin = target[:, 3:]
        # Per-level smoothness losses.
        pyramid_smooth_losses = []
        # Per-level photometric (warp) losses.
        pyramid_warp_losses = []
        # Per-level occlusion masks, kept on self for later inspection.
        self.pyramid_occu_mask1 = []
        self.pyramid_occu_mask2 = []

        s = 1.
        # Process each level of the flow pyramid.
        for i, flow in enumerate(pyramid_flows):
            # Skip levels whose loss weight is zero (placeholder keeps
            # the lists aligned with cfg.w_scales).
            if self.cfg.w_scales[i] == 0:
                pyramid_warp_losses.append(0)
                pyramid_smooth_losses.append(0)
                continue

            # Flow size at this level.
            b, _, h, w = flow.size()

            # resize images to match the size of layer
            im1_scaled = F.interpolate(im1_origin, (h, w), mode='area')
            im2_scaled = F.interpolate(im2_origin, (h, w), mode='area')
            # Reconstruct each image by warping the other with the flow
            # (channels :2 = forward flow, 2: = backward flow).
            im1_recons = flow_warp(im2_scaled, flow[:, :2], pad=self.cfg.warp_pad)
            im2_recons = flow_warp(im1_scaled, flow[:, 2:], pad=self.cfg.warp_pad)

            # Occlusion masks are estimated only at the finest level.
            if i == 0:
                # Occlusion from the backward-flow range map.
                if self.cfg.occ_from_back:
                    occu_mask1 = 1 - get_occu_mask_backward(flow[:, 2:], th=0.2)
                    occu_mask2 = 1 - get_occu_mask_backward(flow[:, :2], th=0.2)
                # Occlusion from forward-backward consistency.
                else:
                    occu_mask1 = 1 - get_occu_mask_bidirection(flow[:, :2], flow[:, 2:])
                    occu_mask2 = 1 - get_occu_mask_bidirection(flow[:, 2:], flow[:, :2])
            # Coarser levels reuse the level-0 masks via nearest-neighbor
            # downsampling.
            else:
                occu_mask1 = F.interpolate(self.pyramid_occu_mask1[0],
                                           (h, w), mode='nearest')
                occu_mask2 = F.interpolate(self.pyramid_occu_mask2[0],
                                           (h, w), mode='nearest')
            # Store the masks for this level.
            self.pyramid_occu_mask1.append(occu_mask1)
            self.pyramid_occu_mask2.append(occu_mask2)

            # Scale factor for the smoothness term, fixed from the finest
            # level's spatial size.
            if i == 0:
                s = min(h, w)

            # Photometric loss for image 1.
            loss_warp = self.loss_photomatric(im1_scaled, im1_recons, occu_mask1)
            # Smoothness loss on the (normalized) forward flow.
            loss_smooth = self.loss_smooth(flow[:, :2] / s, im1_scaled)
            # Optionally add the backward-direction losses and average.
            if self.cfg.with_bk:
                loss_warp += self.loss_photomatric(im2_scaled, im2_recons,
                                                   occu_mask2)
                loss_smooth += self.loss_smooth(flow[:, 2:] / s, im2_scaled)
                # Average over the two directions.
                loss_warp /= 2.
                loss_smooth /= 2.

            # Collect per-level losses.
            pyramid_warp_losses.append(loss_warp)
            pyramid_smooth_losses.append(loss_smooth)

        # Weight each level's photometric and smoothness losses and sum them
        # into the overall loss.
        pyramid_warp_losses = [l * w for l, w in
                               zip(pyramid_warp_losses, self.cfg.w_scales)]
        pyramid_smooth_losses = [l * w for l, w in
                                 zip(pyramid_smooth_losses, self.cfg.w_sm_scales)]

        warp_loss = sum(pyramid_warp_losses)
        smooth_loss = self.cfg.w_smooth * sum(pyramid_smooth_losses)
        total_loss = warp_loss + smooth_loss
        # Return total / photometric / smoothness losses and the mean
        # absolute flow at the finest level (for logging).
        return total_loss, warp_loss, smooth_loss, pyramid_flows[0].abs().mean()
  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值