前言
我们这里直接使用PWCNet了,先了解一下基线和接口啥的,后面就可以换成自己可以创新的模型了。我们借助的ARFlow的工程,里面真的很贴心呀,没有使用AR的和使用AR的两个版本,我们这里自然就是使用前者嘛!因为其实我们要进行方法创新的时候也不能直接套用它的方法呀。
模型部分
这里其实最主要的就是输入一个图像序列,然后返回一个光流金字塔字典,我们这里直接copy的ARFlow的代码,详细的去看下工程,这里就不附上了,因为这个地方的变化是最大的。
训练部分:base
ARFlow的训练代码写得好清晰呀,和数据dataset部分一样,我们这里先创建一个base的基类(可复用),老规矩直接附上代码注释。
import torch
import numpy as np
from abc import abstractmethod
from tensorboardX import SummaryWriter
from utils.torch_utils import bias_parameters, weight_parameters, \
load_checkpoint, save_checkpoint, AdamW
class BaseTrainer:
    """
    Base class for all trainers.

    Owns the shared training state (model, optimizer, data loaders, logging,
    checkpointing) and the outer epoch loop; subclasses implement
    `_run_one_epoch` and `_validate_with_gt`.
    """

    def __init__(self, train_loader, valid_loader, model, loss_func,
                 _log, save_root, config):
        # Logger used for all console/file output.
        self._log = _log
        # Experiment configuration object.
        self.cfg = config
        # Directory where checkpoints and TensorBoard logs are written.
        self.save_root = save_root
        # tensorboardX writer for scalar visualization.
        self.summary_writer = SummaryWriter(str(save_root))
        # Training / validation data loaders.
        self.train_loader, self.valid_loader = train_loader, valid_loader
        # Resolve torch device and GPU id list from config['n_gpu'].
        self.device, self.device_ids = self._prepare_device(config['n_gpu'])
        # Model moved to device, optionally loaded from a checkpoint,
        # then wrapped in DataParallel.
        self.model = self._init_model(model)
        # Optimizer built over the wrapped model's parameters.
        self.optimizer = self._create_optimizer()
        # Loss function callable.
        self.loss_func = loss_func
        # Lowest validation error seen so far; save_model uses it to decide
        # whether the current checkpoint is the new best (not a stop criterion).
        self.best_error = np.inf
        # Epoch / iteration counters, advanced by the subclass loop.
        self.i_epoch = 0
        self.i_iter = 0

    @abstractmethod
    def _run_one_epoch(self):
        # Train for a single epoch; must be provided by subclasses.
        ...

    @abstractmethod
    def _validate_with_gt(self):
        # Validate against ground truth; must return (errors, error_names).
        ...

    def train(self):
        """Run the full schedule: training epochs plus periodic validation."""
        for epoch in range(self.cfg.epoch_num):
            # Run one training epoch (subclass-defined).
            self._run_one_epoch()
            # Validate every `val_epoch_size` epochs.
            if self.i_epoch % self.cfg.val_epoch_size == 0:
                # Evaluate on ground truth and collect named error metrics.
                errors, error_names = self._validate_with_gt()
                # Format "name: value" pairs for logging.
                valid_res = ' '.join(
                    '{}: {:.2f}'.format(*t) for t in zip(error_names, errors))
                self._log.info(' * Epoch {} '.format(self.i_epoch) + valid_res)

    def _init_model(self, model):
        """Move `model` to the device, optionally load pretrained weights,
        and wrap it in DataParallel."""
        model = model.to(self.device)
        if self.cfg.pretrained_model:
            self._log.info("=> using pre-trained weights {}.".format(
                self.cfg.pretrained_model))
            # Load (epoch, state_dict) from the checkpoint file.
            epoch, weights = load_checkpoint(self.cfg.pretrained_model)
            from collections import OrderedDict
            new_weights = OrderedDict()
            # Names in the model's current state dict.
            model_keys = list(model.state_dict().keys())
            # Names in the checkpoint.
            weight_keys = list(weights.keys())
            # NOTE(review): weights are matched to model parameters purely by
            # position (zip of the two key lists), not by name — this assumes
            # the checkpoint enumerates parameters in the same order as the
            # model. Verify when loading checkpoints from other code bases.
            for a, b in zip(model_keys, weight_keys):
                new_weights[a] = weights[b]
            weights = new_weights
            model.load_state_dict(weights)
        else:
            # Train from scratch with the model's own initialization.
            self._log.info("=> Train from scratch.")
            model.init_weights()
        # Wrap for (possibly multi-) GPU execution.
        model = torch.nn.DataParallel(model, device_ids=self.device_ids)
        return model

    def _create_optimizer(self):
        """Build the optimizer with separate weight-decay settings for
        bias parameters and weight parameters."""
        self._log.info('=> setting Adam solver')
        param_groups = [
            {'params': bias_parameters(self.model.module),
             'weight_decay': self.cfg.bias_decay},
            {'params': weight_parameters(self.model.module),
             'weight_decay': self.cfg.weight_decay}]
        if self.cfg.optim == 'adamw':
            optimizer = AdamW(param_groups, self.cfg.lr,
                              betas=(self.cfg.momentum, self.cfg.beta))
        elif self.cfg.optim == 'adam':
            optimizer = torch.optim.Adam(param_groups, self.cfg.lr,
                                         betas=(self.cfg.momentum, self.cfg.beta),
                                         eps=1e-7)
        else:
            raise NotImplementedError(self.cfg.optim)
        return optimizer

    def _prepare_device(self, n_gpu_use):
        """
        setup GPU device if available, move model into configured device
        """
        n_gpu = torch.cuda.device_count()
        # Requested GPUs but none available: fall back to CPU.
        if n_gpu_use > 0 and n_gpu == 0:
            self._log.warning("Warning: There\'s no GPU available on this machine,"
                              "training will be performed on CPU.")
            n_gpu_use = 0
        # Requested more GPUs than available: clamp to what exists.
        if n_gpu_use > n_gpu:
            self._log.warning(
                "Warning: The number of GPU\'s configured to use is {}, "
                "but only {} are available.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def save_model(self, error, name):
        """Checkpoint the model; also mark it as best when `error` improves
        on the lowest validation error seen so far."""
        is_best = error < self.best_error
        if is_best:
            self.best_error = error
        models = {'epoch': self.i_epoch,
                  'state_dict': self.model.module.state_dict()}
        save_checkpoint(self.save_root, models, name, is_best)
有些代码是在ARFlow工程中的utils中。我将上面要用到的代码直接复制如下:
# torch_utils
import torch
import shutil
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import numbers
import random
import math
from torch.optim import Optimizer
def init_seed(seed):
    """Seed every RNG (torch CPU/GPU, numpy, stdlib random) for reproducibility."""
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all,
                    np.random.seed, random.seed):
        seed_fn(seed)
def weight_parameters(module):
    """Collect every parameter of `module` whose name contains 'weight'."""
    selected = []
    for name, param in module.named_parameters():
        if 'weight' in name:
            selected.append(param)
    return selected
def bias_parameters(module):
    """Collect every parameter of `module` whose name contains 'bias'."""
    return [param for name, param in module.named_parameters()
            if name.find('bias') != -1]
def load_checkpoint(model_path):
    """Load a checkpoint file and split it into (epoch, state_dict).

    The checkpoint may be a plain state dict, or a dict with optional
    'epoch' and 'state_dict' entries (as written by save_checkpoint).

    :param model_path: path to the checkpoint file
    :return: (epoch or None, state_dict)
    """
    # map_location='cpu' makes GPU-saved checkpoints loadable on CPU-only
    # machines; tensors are moved to the right device later via model.to().
    weights = torch.load(model_path, map_location='cpu')
    epoch = None
    if 'epoch' in weights:
        # Remove the epoch entry so it is not mistaken for a parameter.
        epoch = weights.pop('epoch')
    if 'state_dict' in weights:
        state_dict = (weights['state_dict'])
    else:
        # Plain state dict saved directly.
        state_dict = weights
    return epoch, state_dict
def save_checkpoint(save_path, states, file_prefixes, is_best, filename='ckpt.pth.tar'):
    """Save one or several checkpoints under `save_path`.

    `file_prefixes` may be a single prefix string (then `states` is one
    checkpoint dict) or an iterable of prefixes zipped with an iterable of
    checkpoint dicts. When `is_best` is True, each saved file is also copied
    to '<prefix>_model_best.pth.tar'.
    """
    def _save_single(state, prefix):
        target = save_path / '{}_{}'.format(prefix, filename)
        torch.save(state, target)
        if is_best:
            # Keep a separate copy of the best-performing checkpoint.
            shutil.copyfile(target,
                            save_path / '{}_model_best.pth.tar'.format(prefix))

    if isinstance(file_prefixes, str):
        _save_single(states, file_prefixes)
    else:
        for prefix, state in zip(file_prefixes, states):
            _save_single(state, prefix)
def restore_model(model, pretrained_file):
    """Restore `model` from `pretrained_file`, matching weights by name.

    Keys present in the model but missing from the checkpoint keep their
    current values; checkpoint keys unknown to the model are reported and
    dropped.
    """
    from collections import OrderedDict

    epoch, weights = load_checkpoint(pretrained_file)
    model_keys = set(model.state_dict().keys())
    weight_keys = set(weights.keys())

    # Model parameters without a matching checkpoint entry.
    missing_in_weights = sorted(list(model_keys - weight_keys))
    if len(missing_in_weights):
        print('Warning: There are weights in model but not in pre-trained.')
        for key in missing_in_weights:
            print(key)
            # Fall back to the model's current value for this parameter.
            weights[key] = model.state_dict()[key]

    # Checkpoint entries the model does not know about.
    extra_in_weights = sorted(list(weight_keys - model_keys))
    if len(extra_in_weights):
        print('Warning: There are pre-trained weights not in model.')
        for key in extra_in_weights:
            print(key)

    # Keep exactly the keys the model expects (load weights by name).
    filtered = OrderedDict((key, weights[key]) for key in model_keys)
    model.load_state_dict(filtered)
    return model
class AdamW(Optimizer):
    """Implements AdamW algorithm.

    It has been proposed in `Fixing Weight Decay Regularization in Adam`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)

    .. _Fixing Weight Decay Regularization in Adam:
        https://arxiv.org/abs/1711.05101
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super(AdamW, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdamW does not support sparse gradients, please consider SparseAdam instead')
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1
                # Decay the first and second moment running average coefficient.
                # Keyword `alpha`/`value` forms are used here: the old
                # positional scalar-first overloads were deprecated and then
                # removed in modern PyTorch.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['eps'])
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
                p.data.addcdiv_(exp_avg, denom, value=-step_size)
                # Decoupled weight decay, applied after the Adam update:
                # p <- (1 - weight_decay) * p.
                # NOTE(review): as in the original implementation, the decay
                # is NOT scaled by the learning rate.
                if group['weight_decay'] != 0:
                    p.data.mul_(1 - group['weight_decay'])
        return loss
我们这里得去学习一下光流预测的评估,它写在了一个函数中:
def evaluate_flow(gt_flows, pred_flows, moving_masks=None):
    # credit "undepthflow/eval/evaluate_flow.py"
    """Evaluate predicted flows against ground truth.

    :param gt_flows: list of HxWx2 arrays (flow only) or HxWx4 arrays
        (flow + valid/occ mask + non-occluded mask, as stacked by the trainer)
    :param pred_flows: list of hxwx2 predicted flows; resized to gt size here
    :param moving_masks: optional per-sample masks splitting moving vs. static
    :return: [EPE] for 2-channel gt, or
        [EPE, E_noc, E_occ, F1_all (+ E_move, E_static)] for 4-channel gt

    NOTE(review): relies on `cv2` and `np` being imported by the enclosing
    module; `cv2` is not imported in this snippet — confirm at integration.
    """
    def calculate_error_rate(epe_map, gt_flow, mask):
        """KITTI-style outlier percentage within `mask`."""
        # A pixel is 'bad' when its EPE exceeds 3 px AND exceeds 5% of the
        # ground-truth flow magnitude (magnitude clamped to 1e-10 to avoid
        # division by zero).
        bad_pixels = np.logical_and(
            epe_map * mask > 3,
            epe_map * mask / np.maximum(
                np.sqrt(np.sum(np.square(gt_flow), axis=2)), 1e-10) > 0.05)
        # Fraction of bad pixels inside the mask, as a percentage.
        return bad_pixels.sum() / mask.sum() * 100.

    # Accumulators: overall / non-occluded / occluded / moving / static EPE
    # and the overall outlier rate.
    error, error_noc, error_occ, error_move, error_static, error_rate = \
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0
    # Outlier rates restricted to moving / static regions.
    error_move_rate, error_static_rate = 0.0, 0.0
    # Number of samples.
    B = len(gt_flows)
    for gt_flow, pred_flow, i in zip(gt_flows, pred_flows, range(B)):
        # Ground-truth and predicted resolutions.
        H, W = gt_flow.shape[:2]
        h, w = pred_flow.shape[:2]
        # Work on a copy so the caller's array is left untouched.
        pred_flow = np.copy(pred_flow)
        # Rescale the flow vectors to the ground-truth resolution.
        pred_flow[:, :, 0] = pred_flow[:, :, 0] / w * W
        pred_flow[:, :, 1] = pred_flow[:, :, 1] / h * H
        # Bilinearly resize the flow field itself to the gt size.
        flo_pred = cv2.resize(pred_flow, (W, H), interpolation=cv2.INTER_LINEAR)
        # Per-pixel endpoint error: L2 distance over the two flow channels.
        epe_map = np.sqrt(
            np.sum(np.square(flo_pred[:, :, :2] - gt_flow[:, :, :2]),
                   axis=2))
        if gt_flow.shape[-1] == 2:
            # Dense gt without masks: plain mean EPE.
            error += np.mean(epe_map)
        elif gt_flow.shape[-1] == 4:
            # Channel 2 is the valid/occ mask: masked mean EPE.
            error += np.sum(epe_map * gt_flow[:, :, 2]) / np.sum(gt_flow[:, :, 2])
            # Channel 3 is the non-occluded mask.
            noc_mask = gt_flow[:, :, -1]
            # EPE over non-occluded pixels.
            error_noc += np.sum(epe_map * noc_mask) / np.sum(noc_mask)
            # EPE over occluded pixels = valid minus non-occluded
            # (denominator guarded against zero).
            error_occ += np.sum(epe_map * (gt_flow[:, :, 2] - noc_mask)) / max(
                np.sum(gt_flow[:, :, 2] - noc_mask), 1.0)
            # KITTI F1 outlier rate over valid pixels.
            error_rate += calculate_error_rate(epe_map, gt_flow[:, :, 0:2],
                                               gt_flow[:, :, 2])
            # Optional split into moving vs. static regions.
            if moving_masks is not None:
                move_mask = moving_masks[i]
                error_move_rate += calculate_error_rate(
                    epe_map, gt_flow[:, :, 0:2], gt_flow[:, :, 2] * move_mask)
                error_static_rate += calculate_error_rate(
                    epe_map, gt_flow[:, :, 0:2],
                    gt_flow[:, :, 2] * (1.0 - move_mask))
                error_move += np.sum(epe_map * gt_flow[:, :, 2] *
                                     move_mask) / np.sum(gt_flow[:, :, 2] *
                                                         move_mask)
                error_static += np.sum(epe_map * gt_flow[:, :, 2] * (
                    1.0 - move_mask)) / np.sum(gt_flow[:, :, 2] *
                                               (1.0 - move_mask))
    # Return averaged metrics; the set depends on the gt channel layout.
    if gt_flows[0].shape[-1] == 4:
        res = [error / B, error_noc / B, error_occ / B, error_rate / B]
        if moving_masks is not None:
            res += [error_move / B, error_static / B]
        return res
    else:
        return [error / B]
这之后我们还得看一下计算和存储平均值的类 AverageMeter,用于跟踪一个或多个指标的平均值和当前值。
import collections
# 更新字典(递归合并)
def update_dict(orig_dict, new_dict):
    """Recursively merge `new_dict` into `orig_dict` (in place) and return it.

    Nested mappings are merged key-by-key; any other value overwrites the
    existing entry.
    """
    # `collections.Mapping` was removed in Python 3.10; the ABC now lives
    # in collections.abc.
    from collections.abc import Mapping
    for key, val in new_dict.items():
        if isinstance(val, Mapping):
            # Merge into the existing sub-dict (or a fresh one).
            tmp = update_dict(orig_dict.get(key, {}), val)
            orig_dict[key] = tmp
        else:
            orig_dict[key] = val
    return orig_dict
class AverageMeter(object):
    """Track current values and running averages for one or more metrics."""

    def __init__(self, i=1, precision=3, names=None):
        # Number of metrics tracked in parallel.
        self.meters = i
        # Decimal places used by __repr__.
        self.precision = precision
        # Start with all counters zeroed.
        self.reset(self.meters)
        # Optional display names, one per metric.
        self.names = names
        if names is None:
            self.names = [''] * self.meters
        else:
            assert self.meters == len(self.names)

    def reset(self, i):
        """Zero out value / sum / count / average for `i` metrics."""
        self.val, self.avg, self.sum, self.count = ([0] * i for _ in range(4))

    def update(self, val, n=1):
        """Record new value(s) `val` with weight(s) `n`, refresh the averages."""
        vals = val if isinstance(val, list) else [val]
        weights = n if isinstance(n, list) else [n] * self.meters
        assert (len(vals) == self.meters and len(weights) == self.meters)
        for idx, (v, w) in enumerate(zip(vals, weights)):
            self.count[idx] += w
            self.val[idx] = v
            self.sum[idx] += v * w
            self.avg[idx] = self.sum[idx] / self.count[idx]

    def __repr__(self):
        """Return 'name current ... (name average ...)' at set precision."""
        current = ' '.join('{} {:.{}f}'.format(n, v, self.precision)
                           for n, v in zip(self.names, self.val))
        running = ' '.join('{} {:.{}f}'.format(n, a, self.precision)
                           for n, a in zip(self.names, self.avg))
        return '{} ({})'.format(current, running)
在Python中,__repr__ 是一个特殊方法,用于定义对象的"官方"字符串表示形式。需要注意的是,print(obj) 和 str(obj) 实际上优先调用的是 __str__;只有当对象没有定义 __str__ 时,才会回退到 __repr__。此外,在交互式解释器中直接显示对象时也会调用 __repr__。
sintel_trainer
我们先创建一个类
import time
import torch
from .base_trainer import BaseTrainer
from utils.flow_utils import evaluate_flow
from utils.misc_utils import AverageMeter
class TrainFramework(BaseTrainer):
    """Concrete training framework; all common setup lives in BaseTrainer."""

    def __init__(self, train_loader, valid_loader, model, loss_func,
                 _log, save_root, config):
        # Delegate the whole construction to the base trainer.
        super().__init__(
            train_loader, valid_loader, model, loss_func, _log, save_root, config)
之后在base中我们写了循环epoch的代码,那么这里我们就要去写一次epoch怎么训练就好了,这个其实也不是固定的,根据不同的任务,要去书写不同的代码,比如ARFlow中的两次前向传播的过程,就和我们下面的写的基础训练不一样的,如果我们后面要进行创新,这里面也是要自己去书写的。
def _run_one_epoch(self):
    """Train the model for a single epoch over self.train_loader."""
    # Meters for per-batch wall time and data-loading time.
    am_batch_time = AverageMeter()
    am_data_time = AverageMeter()
    # Tracked metrics: total loss, photometric loss, smoothness loss,
    # and the mean absolute value of the finest predicted flow.
    key_meter_names = ['Loss', 'l_ph', 'l_sm', 'flow_mean']
    key_meters = AverageMeter(i=len(key_meter_names), precision=4)
    # Switch the model to training mode.
    self.model.train()
    end = time.time()
    # When the configured 'stage1' epoch arrives, swap in the stage-1 loss
    # configuration — this lets different phases of training use different
    # loss settings.
    if 'stage1' in self.cfg:
        if self.i_epoch == self.cfg.stage1.epoch:
            self.loss_func.cfg.update(self.cfg.stage1.loss)
    for i_step, data in enumerate(self.train_loader):
        # Cap the number of steps per epoch.
        if i_step > self.cfg.epoch_size:
            break
        # read data to device
        img1, img2 = data['img1'], data['img2']
        img_pair = torch.cat([img1, img2], 1).to(self.device)
        # measure data loading time
        am_data_time.update(time.time() - end)
        # compute output: forward and backward flow pyramids
        res_dict = self.model(img_pair, with_bk=True)
        flows_12, flows_21 = res_dict['flows_fw'], res_dict['flows_bw']
        # Concatenate forward/backward flows channel-wise per pyramid level.
        flows = [torch.cat([flo12, flo21], 1) for flo12, flo21 in
                 zip(flows_12, flows_21)]
        # Compute the unsupervised loss terms.
        loss, l_ph, l_sm, flow_mean = self.loss_func(flows, img_pair)
        # update meters
        key_meters.update([loss.item(), l_ph.item(), l_sm.item(), flow_mean.item()],
                          img_pair.size(0))
        # compute gradient and do optimization step
        self.optimizer.zero_grad()
        # loss.backward()
        # Scale the loss by 1024 before backward (the gradients are divided
        # back below) so that very small gradients stay representable.
        scaled_loss = 1024. * loss
        scaled_loss.backward()
        for param in [p for p in self.model.parameters() if p.requires_grad]:
            # Undo the 1024x scaling on the gradients.
            param.grad.data.mul_(1. / 1024)
        # Parameter update.
        self.optimizer.step()
        # measure elapsed time
        am_batch_time.update(time.time() - end)
        end = time.time()
        # Periodically record the current metric values to TensorBoard.
        if self.i_iter % self.cfg.record_freq == 0:
            for v, name in zip(key_meters.val, key_meter_names):
                self.summary_writer.add_scalar('Train_' + name, v, self.i_iter)
        # Periodically print progress to the logger.
        if self.i_iter % self.cfg.print_freq == 0:
            istr = '{}:{:04d}/{:04d}'.format(
                self.i_epoch, i_step, self.cfg.epoch_size) + \
                ' Time {} Data {}'.format(am_batch_time, am_data_time) + \
                ' Info {}'.format(key_meters)
            self._log.info(istr)
        self.i_iter += 1
    self.i_epoch += 1
之后我们就是要去看一下验证的代码,我们在前面写了一个验证的函数,下面我们会调用它的。
import time
import torch
import numpy as np
from .base_trainer import BaseTrainer
from utils.flow_utils import load_flow, evaluate_flow
from utils.misc_utils import AverageMeter
class TrainFramework(BaseTrainer):
    """Training framework: one-epoch training loop plus ground-truth
    validation, built on BaseTrainer."""

    def __init__(self, train_loader, valid_loader, model, loss_func,
                 _log, save_root, config):
        super(TrainFramework, self).__init__(
            train_loader, valid_loader, model, loss_func, _log, save_root, config)

    def _run_one_epoch(self):
        """Train the model for a single epoch over self.train_loader."""
        # Meters for batch wall time and data loading time.
        am_batch_time = AverageMeter()
        am_data_time = AverageMeter()
        # Tracked metrics: total loss, photometric loss, smoothness loss,
        # and the mean absolute value of the finest predicted flow.
        key_meter_names = ['Loss', 'l_ph', 'l_sm', 'flow_mean']
        key_meters = AverageMeter(i=len(key_meter_names), precision=4)
        self.model.train()
        end = time.time()
        # Swap in the stage-1 loss configuration when its epoch arrives.
        if 'stage1' in self.cfg:
            if self.i_epoch == self.cfg.stage1.epoch:
                self.loss_func.cfg.update(self.cfg.stage1.loss)
        for i_step, data in enumerate(self.train_loader):
            if i_step > self.cfg.epoch_size:
                break
            # read data to device
            img1, img2 = data['img1'], data['img2']
            img_pair = torch.cat([img1, img2], 1).to(self.device)
            # measure data loading time
            am_data_time.update(time.time() - end)
            # compute output: forward and backward flow pyramids
            res_dict = self.model(img_pair, with_bk=True)
            flows_12, flows_21 = res_dict['flows_fw'], res_dict['flows_bw']
            flows = [torch.cat([flo12, flo21], 1) for flo12, flo21 in
                     zip(flows_12, flows_21)]
            loss, l_ph, l_sm, flow_mean = self.loss_func(flows, img_pair)
            # update meters
            key_meters.update([loss.item(), l_ph.item(), l_sm.item(), flow_mean.item()],
                              img_pair.size(0))
            # compute gradient and do optimization step
            self.optimizer.zero_grad()
            # loss.backward()
            # Scale the loss by 1024 before backward and divide the gradients
            # back afterwards, to keep small gradients representable.
            scaled_loss = 1024. * loss
            scaled_loss.backward()
            for param in [p for p in self.model.parameters() if p.requires_grad]:
                param.grad.data.mul_(1. / 1024)
            self.optimizer.step()
            # measure elapsed time
            am_batch_time.update(time.time() - end)
            end = time.time()
            # Periodically log scalars to TensorBoard.
            if self.i_iter % self.cfg.record_freq == 0:
                for v, name in zip(key_meters.val, key_meter_names):
                    self.summary_writer.add_scalar('Train_' + name, v, self.i_iter)
            # Periodically print progress to the logger.
            if self.i_iter % self.cfg.print_freq == 0:
                istr = '{}:{:04d}/{:04d}'.format(
                    self.i_epoch, i_step, self.cfg.epoch_size) + \
                    ' Time {} Data {}'.format(am_batch_time, am_data_time) + \
                    ' Info {}'.format(key_meters)
                self._log.info(istr)
            self.i_iter += 1
        self.i_epoch += 1

    @torch.no_grad()
    def _validate_with_gt(self):
        """Validate against ground-truth flow.

        Returns (all_error_avgs, all_error_names), one entry per metric per
        validation set.
        """
        batch_time = AverageMeter()
        # Accept either a single loader or a list of loaders (one per set).
        if type(self.valid_loader) is not list:
            self.valid_loader = [self.valid_loader]
        # only use the first GPU to run validation, multiple GPUs might raise error.
        # https://github.com/Eromera/erfnet_pytorch/issues/2#issuecomment-486142360
        self.model = self.model.module
        self.model.eval()
        end = time.time()
        all_error_names = []
        all_error_avgs = []
        n_step = 0
        for i_set, loader in enumerate(self.valid_loader):
            # Standard KITTI-style metrics.
            error_names = ['EPE', 'E_noc', 'E_occ', 'F1_all']
            error_meters = AverageMeter(i=len(error_names))
            for i_step, data in enumerate(loader):
                img1, img2 = data['img1'], data['img2']
                img_pair = torch.cat([img1, img2], 1).to(self.device)
                # Ground-truth flow plus its valid/occ mask.
                res = list(map(load_flow, data['flow_occ']))
                gt_flows, occ_masks = [r[0] for r in res], [r[1] for r in res]
                # Non-occluded mask (flow values themselves are discarded).
                res = list(map(load_flow, data['flow_noc']))
                _, noc_masks = [r[0] for r in res], [r[1] for r in res]
                # Stack flow (2ch) + occ mask + noc mask into one 4-channel
                # array per sample, as expected by evaluate_flow.
                gt_flows = [np.concatenate([flow, occ_mask, noc_mask], axis=2) for
                            flow, occ_mask, noc_mask in
                            zip(gt_flows, occ_masks, noc_masks)]
                # compute output: finest-level forward flow only
                flows = self.model(img_pair)['flows_fw']
                pred_flows = flows[0].detach().cpu().numpy().transpose([0, 2, 3, 1])
                es = evaluate_flow(gt_flows, pred_flows)
                error_meters.update([l.item() for l in es], img_pair.size(0))
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                # Periodic progress logging (and always on the last batch).
                if i_step % self.cfg.print_freq == 0 or i_step == len(loader) - 1:
                    self._log.info('Test: {0}[{1}/{2}]\t Time {3}\t '.format(
                        i_set, i_step, self.cfg.valid_size, batch_time) + ' '.join(
                        map('{:.2f}'.format, error_meters.avg)))
                # Cap the number of validation batches.
                if i_step > self.cfg.valid_size:
                    break
            n_step += len(loader)
            # write error to tf board.
            for value, name in zip(error_meters.avg, error_names):
                self.summary_writer.add_scalar(
                    'Valid_{}_{}'.format(name, i_set), value, self.i_epoch)
            all_error_avgs.extend(error_meters.avg)
            all_error_names.extend(['{}_{}'.format(name, i_set) for name in error_names])
        # Re-wrap the model for multi-GPU training after validation.
        self.model = torch.nn.DataParallel(self.model, device_ids=self.device_ids)
        # In order to reduce the space occupied during debugging,
        # only the model with more than cfg.save_iter iterations will be saved.
        if self.i_iter > self.cfg.save_iter:
            self.save_model(all_error_avgs[0], name='KITTI_Flow')
        return all_error_avgs, all_error_names
这里面验证集是一个列表,因为sintel和kitti的验证数据都各有两个子集。
kitti_trainer
运行一个epoch是一样的,这里仅更改验证
import time
import torch
import numpy as np
from .base_trainer import BaseTrainer
from utils.flow_utils import load_flow, evaluate_flow
from utils.misc_utils import AverageMeter
class TrainFramework(BaseTrainer):
    """KITTI training framework: same one-epoch training loop as the Sintel
    trainer, with KITTI-specific ground-truth validation."""

    def __init__(self, train_loader, valid_loader, model, loss_func,
                 _log, save_root, config):
        super(TrainFramework, self).__init__(
            train_loader, valid_loader, model, loss_func, _log, save_root, config)

    def _run_one_epoch(self):
        """Train the model for a single epoch over self.train_loader."""
        # Meters for batch wall time and data loading time.
        am_batch_time = AverageMeter()
        am_data_time = AverageMeter()
        # Tracked metrics: total loss, photometric loss, smoothness loss,
        # and the mean absolute value of the finest predicted flow.
        key_meter_names = ['Loss', 'l_ph', 'l_sm', 'flow_mean']
        key_meters = AverageMeter(i=len(key_meter_names), precision=4)
        self.model.train()
        end = time.time()
        # Swap in the stage-1 loss configuration when its epoch arrives.
        if 'stage1' in self.cfg:
            if self.i_epoch == self.cfg.stage1.epoch:
                self.loss_func.cfg.update(self.cfg.stage1.loss)
        for i_step, data in enumerate(self.train_loader):
            if i_step > self.cfg.epoch_size:
                break
            # read data to device
            img1, img2 = data['img1'], data['img2']
            img_pair = torch.cat([img1, img2], 1).to(self.device)
            # measure data loading time
            am_data_time.update(time.time() - end)
            # compute output: forward and backward flow pyramids
            res_dict = self.model(img_pair, with_bk=True)
            flows_12, flows_21 = res_dict['flows_fw'], res_dict['flows_bw']
            flows = [torch.cat([flo12, flo21], 1) for flo12, flo21 in
                     zip(flows_12, flows_21)]
            loss, l_ph, l_sm, flow_mean = self.loss_func(flows, img_pair)
            # update meters
            key_meters.update([loss.item(), l_ph.item(), l_sm.item(), flow_mean.item()],
                              img_pair.size(0))
            # compute gradient and do optimization step
            self.optimizer.zero_grad()
            # loss.backward()
            # Scale the loss by 1024 before backward and divide the gradients
            # back afterwards, to keep small gradients representable.
            scaled_loss = 1024. * loss
            scaled_loss.backward()
            for param in [p for p in self.model.parameters() if p.requires_grad]:
                param.grad.data.mul_(1. / 1024)
            self.optimizer.step()
            # measure elapsed time
            am_batch_time.update(time.time() - end)
            end = time.time()
            # Periodically log scalars to TensorBoard.
            if self.i_iter % self.cfg.record_freq == 0:
                for v, name in zip(key_meters.val, key_meter_names):
                    self.summary_writer.add_scalar('Train_' + name, v, self.i_iter)
            # Periodically print progress to the logger.
            if self.i_iter % self.cfg.print_freq == 0:
                istr = '{}:{:04d}/{:04d}'.format(
                    self.i_epoch, i_step, self.cfg.epoch_size) + \
                    ' Time {} Data {}'.format(am_batch_time, am_data_time) + \
                    ' Info {}'.format(key_meters)
                self._log.info(istr)
            self.i_iter += 1
        self.i_epoch += 1

    @torch.no_grad()
    def _validate_with_gt(self):
        """Validate against KITTI ground-truth flow.

        Returns (all_error_avgs, all_error_names), one entry per metric per
        validation set.
        """
        batch_time = AverageMeter()
        # Accept either a single loader or a list of loaders (one per set).
        if type(self.valid_loader) is not list:
            self.valid_loader = [self.valid_loader]
        # only use the first GPU to run validation, multiple GPUs might raise error.
        # https://github.com/Eromera/erfnet_pytorch/issues/2#issuecomment-486142360
        self.model = self.model.module
        self.model.eval()
        end = time.time()
        all_error_names = []
        all_error_avgs = []
        n_step = 0
        for i_set, loader in enumerate(self.valid_loader):
            error_names = ['EPE', 'E_noc', 'E_occ', 'F1_all']
            error_meters = AverageMeter(i=len(error_names))
            for i_step, data in enumerate(loader):
                img1, img2 = data['img1'], data['img2']
                img_pair = torch.cat([img1, img2], 1).to(self.device)
                res = list(map(load_flow, data['flow_occ']))
                # Ground-truth flow and occlusion (valid) mask.
                gt_flows, occ_masks = [r[0] for r in res], [r[1] for r in res]
                res = list(map(load_flow, data['flow_noc']))
                # Non-occluded mask (flow values themselves are discarded).
                _, noc_masks = [r[0] for r in res], [r[1] for r in res]
                # Concatenate flow field, occ mask and noc mask per sample
                # into one 4-channel array, as expected by evaluate_flow.
                gt_flows = [np.concatenate([flow, occ_mask, noc_mask], axis=2) for
                            flow, occ_mask, noc_mask in
                            zip(gt_flows, occ_masks, noc_masks)]
                # compute output: finest-level forward flow only
                flows = self.model(img_pair)['flows_fw']
                pred_flows = flows[0].detach().cpu().numpy().transpose([0, 2, 3, 1])
                # evaluate_flow takes its 4-channel gt branch here.
                es = evaluate_flow(gt_flows, pred_flows)
                error_meters.update([l.item() for l in es], img_pair.size(0))
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                # Periodic progress logging (and always on the last batch).
                if i_step % self.cfg.print_freq == 0 or i_step == len(loader) - 1:
                    self._log.info('Test: {0}[{1}/{2}]\t Time {3}\t '.format(
                        i_set, i_step, self.cfg.valid_size, batch_time) + ' '.join(
                        map('{:.2f}'.format, error_meters.avg)))
                # Cap the number of validation batches.
                if i_step > self.cfg.valid_size:
                    break
            n_step += len(loader)
            # write error to tf board.
            for value, name in zip(error_meters.avg, error_names):
                self.summary_writer.add_scalar(
                    'Valid_{}_{}'.format(name, i_set), value, self.i_epoch)
            all_error_avgs.extend(error_meters.avg)
            all_error_names.extend(['{}_{}'.format(name, i_set) for name in error_names])
        # Re-wrap the model for multi-GPU training after validation.
        self.model = torch.nn.DataParallel(self.model, device_ids=self.device_ids)
        # In order to reduce the space occupied during debugging,
        # only the model with more than cfg.save_iter iterations will be saved.
        if self.i_iter > self.cfg.save_iter:
            self.save_model(all_error_avgs[0], name='KITTI_Flow')
        return all_error_avgs, all_error_names
flychairs_trainer
这个先略过
损失函数
对于无监督光流估计来说,最主要的就是亮度损失,但是为了更好的精度,增加了平滑度损失等,这里面我们将常用的直接copy过来就行,ARFlow也是使用的是unflow的代码。这里的模块是可以直接copy,但是后面换模型的话,就是没有使用到金字塔,那么这里的代码岂不是也得改吗?
这里注意一下计算亮度损失不是只是单单的使用图片1减去(图片2经过warp后的图片1’),而是计算两者的L1距离,SSIM或者Ternary 损失,因为这些操作是已经被证明更加好用。并且平滑损失也是分一阶平滑和二阶平滑。
# 计算亮度损失,参数为:缩放过的图片1,图片2经过warp后重建出的图片1',遮挡掩码
def loss_photomatric(self, im1_scaled, im1_recons, occu_mask1):
    """Photometric loss between a scaled image and its warped reconstruction.

    Depending on the configured weights, combines an occlusion-masked L1
    term, an SSIM term and a ternary-census term; the sum of the per-term
    means is normalised by the mean of the occlusion mask.
    """
    cfg = self.cfg
    terms = []
    # L1 photometric term.
    if cfg.w_l1 > 0:
        terms.append(cfg.w_l1 * (im1_scaled - im1_recons).abs() * occu_mask1)
    # Structural similarity term.
    if cfg.w_ssim > 0:
        terms.append(cfg.w_ssim * SSIM(im1_recons * occu_mask1,
                                       im1_scaled * occu_mask1))
    # Ternary census transform term.
    if cfg.w_ternary > 0:
        terms.append(cfg.w_ternary * TernaryLoss(im1_recons * occu_mask1,
                                                 im1_scaled * occu_mask1))
    # Normalise by the mask mean so the loss is comparable across
    # different occlusion ratios.
    total = sum(term.mean() for term in terms)
    return total / occu_mask1.mean()
# 计算平滑损失
def loss_smooth(self, flow, im1_scaled):
    """Edge-aware smoothness loss for one flow field.

    Uses second-order smoothness when cfg.smooth_2nd is set, first-order
    otherwise; the image guides the edge-aware weighting via cfg.alpha.
    """
    # Pick the smoothness order from the configuration.
    if 'smooth_2nd' in self.cfg and self.cfg.smooth_2nd:
        func_smooth = smooth_grad_2nd
    else:
        func_smooth = smooth_grad_1st
    return func_smooth(flow, im1_scaled, self.cfg.alpha).mean()
到这里是一样的,就是后面涉及的到金字塔的地方后面可能不一样
# 后期要去修改的地方
def forward(self, output, target):
    """
    :param output: Multi-scale forward/backward flows n * [B x 4 x h x w]
    :param target: image pairs Nx6xHxW
    :return: (total loss, photometric loss, smoothness loss,
              mean |flow| at the finest pyramid level)
    """
    # Pyramid of concatenated forward/backward flows, finest level first.
    pyramid_flows = output
    # Split the 6-channel pair back into the two RGB images.
    im1_origin = target[:, :3]
    im2_origin = target[:, 3:]
    # Per-level smoothness losses.
    pyramid_smooth_losses = []
    # Per-level photometric (warp) losses.
    pyramid_warp_losses = []
    # Per-level occlusion masks, kept on self for inspection.
    self.pyramid_occu_mask1 = []
    self.pyramid_occu_mask2 = []
    s = 1.
    for i, flow in enumerate(pyramid_flows):
        # Zero-weighted levels are skipped; a 0 placeholder keeps the
        # loss lists aligned with the weight lists.
        if self.cfg.w_scales[i] == 0:
            pyramid_warp_losses.append(0)
            pyramid_smooth_losses.append(0)
            continue
        # Spatial size of this pyramid level.
        b, _, h, w = flow.size()
        # resize images to match the size of layer
        im1_scaled = F.interpolate(im1_origin, (h, w), mode='area')
        im2_scaled = F.interpolate(im2_origin, (h, w), mode='area')
        # Reconstruct each image by warping the other with its flow
        # (channels 0:2 = forward flow, 2:4 = backward flow).
        im1_recons = flow_warp(im2_scaled, flow[:, :2], pad=self.cfg.warp_pad)
        im2_recons = flow_warp(im1_scaled, flow[:, 2:], pad=self.cfg.warp_pad)
        # Occlusion masks are estimated only at the finest level (i == 0).
        if i == 0:
            # Estimate occlusion from the backward flow.
            if self.cfg.occ_from_back:
                occu_mask1 = 1 - get_occu_mask_backward(flow[:, 2:], th=0.2)
                occu_mask2 = 1 - get_occu_mask_backward(flow[:, :2], th=0.2)
            # Or from forward-backward consistency.
            else:
                occu_mask1 = 1 - get_occu_mask_bidirection(flow[:, :2], flow[:, 2:])
                occu_mask2 = 1 - get_occu_mask_bidirection(flow[:, 2:], flow[:, :2])
        # Coarser levels reuse the finest-level masks, resized with
        # nearest-neighbour interpolation.
        else:
            occu_mask1 = F.interpolate(self.pyramid_occu_mask1[0],
                                       (h, w), mode='nearest')
            occu_mask2 = F.interpolate(self.pyramid_occu_mask2[0],
                                       (h, w), mode='nearest')
        # Record this level's masks.
        self.pyramid_occu_mask1.append(occu_mask1)
        self.pyramid_occu_mask2.append(occu_mask2)
        # Scale factor taken from the finest level, used to normalise flow
        # magnitudes before the smoothness term.
        if i == 0:
            s = min(h, w)
        # Photometric loss for image 1.
        loss_warp = self.loss_photomatric(im1_scaled, im1_recons, occu_mask1)
        # Smoothness loss on the (normalised) forward flow.
        loss_smooth = self.loss_smooth(flow[:, :2] / s, im1_scaled)
        # Optionally add the symmetric losses for image 2 and average.
        if self.cfg.with_bk:
            loss_warp += self.loss_photomatric(im2_scaled, im2_recons,
                                               occu_mask2)
            loss_smooth += self.loss_smooth(flow[:, 2:] / s, im2_scaled)
            # Average over the two directions.
            loss_warp /= 2.
            loss_smooth /= 2.
        pyramid_warp_losses.append(loss_warp)
        pyramid_smooth_losses.append(loss_smooth)
    # Weight each pyramid level's losses and sum them into the totals.
    pyramid_warp_losses = [l * w for l, w in
                           zip(pyramid_warp_losses, self.cfg.w_scales)]
    pyramid_smooth_losses = [l * w for l, w in
                             zip(pyramid_smooth_losses, self.cfg.w_sm_scales)]
    warp_loss = sum(pyramid_warp_losses)
    smooth_loss = self.cfg.w_smooth * sum(pyramid_smooth_losses)
    total_loss = warp_loss + smooth_loss
    # Also return the mean |flow| of the finest level for monitoring.
    return total_loss, warp_loss, smooth_loss, pyramid_flows[0].abs().mean()