小参数视频插帧IFRNet解读和代码复现

B站视频讲解

概述

此论文标题是《IFRNet: Intermediate Feature Refine Network for Efficient Frame Interpolation》,收录于CVPR 2022,研究的任务是视频插帧,本文方法主要的特点在于模型非常的小和高效。

论文摘要

流行的视频帧插值算法,从连续的输入中生成中间帧,通常依赖于复杂的模型架构以及大量的参数或较大的延迟,这阻碍了它们在各种实时应用中的使用。在这项工作中,我们设计了一个高效的基于编码器-解码器的网络,称为IFRNet,用于快速合成中间帧。它首先从给定的输入中提取金字塔特征,然后与强大的中间特征一起细化双边中间流场,直到生成所需的输出。逐步细化的中间特征不仅可以促进中间流的估计,还可以补偿上下文细节,使IFRNet不需要额外的合成或细化模块。为了充分发挥其潜力,我们进一步提出了一种新颖的任务导向的光流蒸馏损失,专注于学习对帧合成有用的教师知识。同时,对逐步细化的中间特征施加了一个新的几何一致性正则化项,以保持更好的结构布局。在各种基准测试上的实验表明了所提出方法的卓越性能和快速推理速度。代码可在 https://github.com/ltkong218/IFRNet 获取。

创新点

  • 我们设计了一种新颖的IFRNet,用于同时进行中间流估计和中间特征细化,以实现高效的视频帧插值.

  • 新提出了任务导向的流蒸馏损失(Task-oriented flow distillation loss)和特征空间几何一致性损失,分别用于促进IFRNet的中间运动估计和中间特征重建.

  • 基准测试结果表明,我们的IFRNet不仅实现了最先进的VFI精度,还具有快速的推理速度和轻量级的模型大小.

核心方法

IFRNet

网络总流程架构如下图所示:
在这里插入图片描述

对于输入的两帧图像,先由一个金字塔编码器提取出不同分辨率的特征,帧0的 ϕ 0 1 , ϕ 0 2 , ϕ 0 3 , ϕ 0 4 \phi_0^1,\phi_0^2,\phi_0^3,\phi_0^4 ϕ01,ϕ02,ϕ03,ϕ04和帧1的 ϕ 1 1 , ϕ 1 2 , ϕ 1 3 , ϕ 1 4 \phi_1^1,\phi_1^2,\phi_1^3,\phi_1^4 ϕ11,ϕ12,ϕ13,ϕ14

然后将特征与一个T变量输入到解码器4中,它会生成一个小分辨率的光流,以及一个中间特征。光流用于将编码器输出的特征进行backward warp,更好地对齐目标帧的特征。解码器2和3会进行类似的计算,而解码器1的输出会稍微有所不同。解码器1会输出光流,一个mask (图中的M),以及一个3通道的residual(图中的R)。光流用于backward warp输入帧,得到两帧warped frames,mask则用于融合这两张warped frames,得到一个中间帧。最后,residual会对中间帧进行细节补充,得到预测结果。

T变量是用于指示时间戳的,它是一个单通道的,尺寸与输入图像相同的的特征图,数值全部为t,代表要预测的是t时刻的帧,当它是0.5时表示中间帧。

解码器中间预测的三个光流经过上采样后,会与使用预训练模型输出的光流进行损失计算。

解码器得到的三个中间特征 ϕ ^ t 3 , ϕ ^ t 2 , ϕ ^ t 1 \hat\phi_t^3,\hat\phi_t^2,\hat\phi_t^1 ϕ^t3,ϕ^t2,ϕ^t1,会与目标帧得到的特征 ϕ t 3 , ϕ t 2 , ϕ t 1 \phi_t^3,\phi_t^2,\phi_t^1 ϕt3,ϕt2,ϕt1进行损失计算。

Decoder的主要结构如下图所示。它主要有卷积块,resdual结构,以及一个反卷积组成。
在这里插入图片描述

损失函数

  1. 图像重建损失:此损失函数为插帧任务的基本损失函数,目的是为了使生成的图像符合目标中间帧
    在这里插入图片描述

  2. 面向任务的光流蒸馏损失(Task-Oriented Flow Distillation Loss):此损失函数通过调整每个像素位置的鲁棒性值来提供更好的面向插帧任务的中间光流监督信息。给定一个现成光流网络的预测结果作为代理标签,我们可以通过公式
    在这里插入图片描述

  3. 特征空间几何一致性损失(Feature Space Geometry Consistency Loss):此损失函数用来监督中间帧特征与Ground Truth中间帧特征具有一致的场景几何布局,从而提高目标帧合成质量

    在这里插入图片描述

模型表现

作者将本文的方法与其他一些方法在多个数据集上进行了指标对比,如下表所示
在这里插入图片描述

可以看到,在多个数据集上该模型都取得了不错的结果。对于普通版本的模型(表中的IFRNet),它的推理时间非常快,指标表现也不错,而small版本的模型则取得了最快的推理时间以及第二小的模型参数量。

large版本模型通过提升参数量,也取得了非常好的指标结果。

总的来说,该模型在保存较好的行性能的同时,实现了非常快的推理速度,

代码复现

核心模型代码

网络模型的定义在models/IFRNet.py 中,具体的解释参考视频讲解。

训练

数据集预处理

首先需要下载好Vimeo90K数据集。

根据项目的README文件,首先需要generate_flow.py生成Vimeo90K数据集的光流。

需要先在generate_flow.py中的第10行修改数据集所在目录,例如我修改成了

vimeo90k_dir = r'E:\Workspace\Datasets\vimeo_triplet'

然后在终端运行该generate_flow.py文件去生成光流。由于数据较多,这一步需要等待较长的时间。

python generate_flow.py

修改训练代码

具体参考视频讲解,以下给出一些关键点。

首先修改train_vimeo90k.py中读取数据集的代码,在train函数中修改数据集读取的地址:

# 53 行
dataset_train = Vimeo90K_Train_Dataset(dataset_dir=r'E:\Workspace\Datasets\vimeo_triplet', augment=True)

# 59 行
dataset_val = Vimeo90K_Test_Dataset(dataset_dir=r'E:\Workspace\Datasets\vimeo_triplet')

其次由于在windows上运行,需要修改分布式后端

# 159行
dist.init_process_group(backend='gloo', world_size=args.world_size)

其他的一些路径相关错误修改:

# 177 行
args.log_path = args.log_path + '\\' + args.model_name


由于windows目录名称限制,需要改一下日志的文件夹名称,把冒号改成其他的合法字符,如下:

# 37行原始
# log_path = os.path.join(args.log_path, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))

# 修改后
log_path = os.path.join(args.log_path, time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime()))

最后,运行train_vimeo90k.py就可以训练。由于我用的是单机器单卡,将分布式的一些参数修改如下后运行

python -m torch.distributed.launch --nproc_per_node=1 train_vimeo90k.py --world_size 1 --model_name 'IFRNet' --epochs 300 --batch_size 8 --lr_start 1e-4 --lr_end 1e-5

推理插帧

使用 demo_2x.py 指定输入图像和模型权重文件后运行即可。

模型修改

在图像生成任务中,高质量的特征对于提高生成质量是非常重要的,视频插帧中也需要。

在原始的IFRNet中,作者为了达到较高的速度,所以使用了较为简单的特征提取网络。如果想基于该网络进行改进,提高生成质量,最直接的一个思路就是修改它的特征提取器(文中的编码器)。

这里我作为示例,使用了一个resnet50来代替原来简单的编码器。通过修改Resnet中几层的输出通道,对齐原网络的需求就能完成。

核心定义代码如下

class ResNet_feature(nn.Module):

    def __init__(self, block, num_block):
        super().__init__()

        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        self.conv2_x = self._make_layer(block, 32, num_block[0], 2)
        self.conv3_x = self._make_layer(block, 48, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 72, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 96, num_block[3], 2)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        output = self.conv1(x)  # channel: 3 -> 64
        output = self.conv2_x(output)
        f1 = output
        output = self.conv3_x(output)
        f2 = output
        output = self.conv4_x(output)
        f3 = output
        output = self.conv5_x(output)
        f4 = output

        return f1, f2, f3, f4


def resnet50_feature():
    """ return a ResNet 50 object
    """
    return ResNet_feature(BottleNeck, [3, 4, 6, 3])

修改后在IFRNet.py中修改Resnet的定义后可以使用,如下:

class Model(nn.Module):
    def __init__(self, local_rank=-1, lr=1e-4):
        super(Model, self).__init__()
        self.encoder = resnet50_feature()  # 使用resnet50作为特征提取器
        self.decoder4 = Decoder4()
        self.decoder3 = Decoder3()
        self.decoder2 = Decoder2()
        self.decoder1 = Decoder1()
        self.l1_loss = Charbonnier_L1()
        self.tr_loss = Ternary(7)
        self.rb_loss = Charbonnier_Ada()
        self.gc_loss = Geometry(3)

最后在train_vimeo90k.py的修改使用的网络即可。

# 原文件170行
if args.model_name == 'IFRNet':
    from models.IFRNet_resnet import Model

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值