上篇文章《super slomo介绍》已经对slomo进行大概的介绍。
本章对该源码进行一个简单分析,由于本人笔记本显存不够,因此对该源码进行部分修改,可以减少显存占用和训练速度的提升。并且后续使用C++来调用训练好的模型并实现双向光流和中间帧合成。
源码地址:https://github.com/avinashpaliwal/Super-SloMo,官方WIKI有具体的安装过程,不在做介绍。
源码通过torch来实现基于U-NET的神经网络结构和光流算法。其中Model.py提供了训练模型的U-NET体系的建立。
train.py提供了训练相关功能。通过dataloader.SuperSloMo()加载训练集和验证集。训练集和验证集通过create_dataset.py提前生成,说明文档有相关生成命令。
for epoch in range(dict1[‘epoch’] + 1,args.epochs)这里进行整个训练的循环控制。训练过程很简单,首先从训练集获取三帧图frame0、frameT和frame1,其中frame0和frame1用来训练,frameT用来验证。然后把frame0和frame1通过torch.cat进行合并。放入到flowComp网络得到输出。再把输出拆分成F_0_1和F_1_0,后续用来计算光流残差和中间帧合成。最后进行loss计算并反向传播。
为了快速看到训练结果减少训练时间和GPU显存的限制,就在源码基础上做了一些修改,首先把flowComp = model.UNet(6,4)改成flowComp = model.UNet(2,4)这里把24位真彩图改成8位灰度图,同理改ArbTimeFlowIntrp = model.Unet(12, 5)。
MSE_LossFn调用vgg16对Ft_p和IFrame进行MSE计算loss,由于前面改动这里也不能使用,屏蔽掉,改loss=204*recnLoss+102*warpLoss+loss_smooth。
修改model.py中forward:直接跳过s3、s4、s5网络层,直接x=self.up5(s2,s1)。
构造中__init__():改self.conv1=nn.Conv2d(inChannels,4,7,stride=1,padding=3)
Self.conv2=nn.Conv2d(4,4,7,stride=1,padding=3)
Self.down1=down(4,8,5)
Self.up5=up(8,4)
Self.conv3=nn.Conv2d(4,outChannels,3,stride=1,padding=1)
最后修改输出训练模型:原来通过torch.save保存为dict1的集合,现在改成对flowComp和ArbTimeFlowIntrp进行输出。
修改后的代码虽然总体效果不如原来,但速度有了非常大的提升,对GPU的要求也降到非常低,适合用来测试。
下一章会把输出的训练模型转换成torch模型便于C++通过torch的调用。
修改后的代码:https://download.csdn.net/download/u011736517/12558294