二、slomo部分源码分析

上篇文章《super slomo介绍》已经对slomo进行大概的介绍。

         本章对该源码进行一个简单分析,由于本人笔记本显存不够,因此对该源码进行部分修改,可以减少显存占用和训练速度的提升。并且后续使用C++来调用训练好的模型并实现双向光流和中间帧合成。

         源码地址:https://github.com/avinashpaliwal/Super-SloMo,官方WIKI有具体的安装过程,不在做介绍。

         源码通过torch来实现基于U-NET的神经网络结构和光流算法。其中Model.py提供了训练模型的U-NET体系的建立。

train.py提供了训练相关功能。通过dataloader.SuperSloMo()加载训练集和验证集。训练集和验证集通过create_dataset.py提前生成,说明文档有相关生成命令。

         for epoch in range(dict1[‘epoch’] + 1,args.epochs)这里进行整个训练的循环控制。训练过程很简单,首先从训练集获取三帧图frame0、frameT和frame1,其中frame0和frame1用来训练,frameT用来验证。然后把frame0和frame1通过torch.cat进行合并。放入到flowComp网络得到输出。再把输出拆分成F_0_1和F_1_0,后续用来计算光流残差和中间帧合成。最后进行loss计算并反向传播。

         为了快速看到训练结果减少训练时间和GPU显存的限制,就在源码基础上做了一些修改,首先把flowComp = model.UNet(6,4)改成flowComp = model.UNet(2,4)这里把24位真彩图改成8位灰度图,同理改ArbTimeFlowIntrp  = model.Unet(12, 5)。

         MSE_LossFn调用vgg16对Ft_p和IFrame进行MSE计算loss,由于前面改动这里也不能使用,屏蔽掉,改loss=204*recnLoss+102*warpLoss+loss_smooth。

         修改model.py中forward:直接跳过s3、s4、s5网络层,直接x=self.up5(s2,s1)。

         构造中__init__():改self.conv1=nn.Conv2d(inChannels,4,7,stride=1,padding=3)

         Self.conv2=nn.Conv2d(4,4,7,stride=1,padding=3)

         Self.down1=down(4,8,5)

         Self.up5=up(8,4)

         Self.conv3=nn.Conv2d(4,outChannels,3,stride=1,padding=1)

最后修改输出训练模型:原来通过torch.save保存为dict1的集合,现在改成对flowComp和ArbTimeFlowIntrp进行输出。

         修改后的代码虽然总体效果不如原来,但速度有了非常大的提升,对GPU的要求也降到非常低,适合用来测试。

         下一章会把输出的训练模型转换成torch模型便于C++通过torch的调用。

修改后的代码:https://download.csdn.net/download/u011736517/12558294

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值