# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import load_checkpoint
from mmedit.models.common import(PixelShufflePack, ResidualBlockNoBN,
flow_warp, make_layer)
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger
@BACKBONES.register_module()
class BasicVSRNet(nn.Module):
    """BasicVSR network structure for video super-resolution.

    Support only x4 upsampling.

    Paper:
        BasicVSR: The Search for Essential Components in Video
        Super-Resolution and Beyond, CVPR, 2021

    Args:
        mid_channels (int): Channel number of the intermediate features.
            Default: 64.
        num_blocks (int): Number of residual blocks in each propagation
            branch. Default: 30.
        spynet_pretrained (str): Pre-trained model path of SPyNet.
            Default: None.
    """

    def __init__(self, mid_channels=64, num_blocks=30, spynet_pretrained=None):
        super().__init__()

        self.mid_channels = mid_channels

        # optical flow network for feature alignment
        self.spynet = SPyNet(pretrained=spynet_pretrained)

        # propagation branches; the "+ 3" accounts for the LR frame that is
        # concatenated with the propagated feature in `forward`
        self.backward_resblocks = ResidualBlocksWithInputConv(
            mid_channels + 3, mid_channels, num_blocks)
        self.forward_resblocks = ResidualBlocksWithInputConv(
            mid_channels + 3, mid_channels, num_blocks)

        # upsample
        # 1x1 conv that fuses the concatenated backward/forward features
        self.fusion = nn.Conv2d(
            mid_channels * 2, mid_channels, 1, 1, 0, bias=True)
        # PixelShufflePack(in_channels, out_channels, scale_factor,
        #                  upsample_kernel)
        self.upsample1 = PixelShufflePack(
            mid_channels, mid_channels, 2, upsample_kernel=3)
        self.upsample2 = PixelShufflePack(
            mid_channels, 64, 2, upsample_kernel=3)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(
            scale_factor=4, mode='bilinear', align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def check_if_mirror_extended(self, lrs):
        """Check whether the input is a mirror-extended sequence.

        If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
        (t-1-i)-th frame. The result is stored in
        ``self.is_mirror_extended``; nothing is returned.

        Args:
            lrs (tensor): Input LR images with shape (n, t, c, h, w)
        """

        self.is_mirror_extended = False
        # only an even-length sequence can be split into two mirrored halves
        if lrs.size(1) % 2 == 0:
            lrs_1, lrs_2 = torch.chunk(lrs, 2, dim=1)
            if torch.norm(lrs_1 - lrs_2.flip(1)) == 0:
                self.is_mirror_extended = True

    def compute_flow(self, lrs):
        """Compute optical flow using SPyNet for feature warping.

        Note that if the input is an mirror-extended sequence,
        'flows_forward' is not needed, since it is equal to
        'flows_backward.flip(1)'. Relies on ``self.is_mirror_extended``
        having been set by ``check_if_mirror_extended``.

        Args:
            lrs (tensor): Input LR images with shape (n, t, c, h, w)

        Return:
            tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
                flows used for forward-time propagation (current to
                previous), or None for a mirror-extended input.
                'flows_backward' corresponds to the flows used for
                backward-time propagation (current to next).
        """

        n, t, c, h, w = lrs.size()
        # pair every frame with its successor and fold time into the batch
        lrs_1 = lrs[:, :-1, :, :, :].reshape(-1, c, h, w)
        lrs_2 = lrs[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(lrs_1, lrs_2).view(n, t - 1, 2, h, w)

        if self.is_mirror_extended:
            # flows_forward = flows_backward.flip(1), so skip the second pass
            flows_forward = None
        else:
            flows_forward = self.spynet(lrs_2, lrs_1).view(n, t - 1, 2, h, w)

        return flows_forward, flows_backward

    def forward(self, lrs):
        """Forward function for BasicVSR.

        Args:
            lrs (Tensor): Input LR sequence with shape (n, t, c, h, w),
                e.g. (1, 14, 3, 64, 112).

        Returns:
            Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
        """

        n, t, c, h, w = lrs.size()
        assert h >= 64 and w >= 64, (
            'The height and width of inputs should be at least 64, '
            f'but got {h} and {w}.')

        # check whether the input is an extended sequence
        self.check_if_mirror_extended(lrs)

        # compute optical flow
        flows_forward, flows_backward = self.compute_flow(lrs)

        # backward-time propagation: iterate from the last frame to the first
        outputs = []
        feat_prop = lrs.new_zeros(n, self.mid_channels, h, w)
        for i in range(t - 1, -1, -1):
            if i < t - 1:  # no warping required for the last timestep
                flow = flows_backward[:, i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))

            # concatenate the current frame with the propagated feature
            feat_prop = torch.cat([lrs[:, i, :, :, :], feat_prop], dim=1)
            feat_prop = self.backward_resblocks(feat_prop)

            outputs.append(feat_prop)
        # features were collected last-to-first; restore temporal order
        outputs = outputs[::-1]

        # forward-time propagation and upsampling
        feat_prop = torch.zeros_like(feat_prop)
        for i in range(0, t):
            lr_curr = lrs[:, i, :, :, :]  # current frame
            if i > 0:  # no warping required for the first timestep
                if flows_forward is not None:
                    flow = flows_forward[:, i - 1, :, :, :]
                else:
                    # mirror-extended input: flows_forward would equal
                    # flows_backward.flip(1), whose (i - 1)-th entry is
                    # flows_backward[:, -i]
                    flow = flows_backward[:, -i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))

            feat_prop = torch.cat([lr_curr, feat_prop], dim=1)
            feat_prop = self.forward_resblocks(feat_prop)

            # upsampling given the backward and forward features
            out = torch.cat([outputs[i], feat_prop], dim=1)
            out = self.lrelu(self.fusion(out))
            out = self.lrelu(self.upsample1(out))
            out = self.lrelu(self.upsample2(out))
            out = self.lrelu(self.conv_hr(out))
            out = self.conv_last(out)
            # the network output is added to the bilinearly upsampled frame
            base = self.img_upsample(lr_curr)
            out += base
            # reuse the list slot: the backward feature is no longer needed
            outputs[i] = out

        # stack along a new dim=1, i.e. the t axis of (n, t, c, 4h, 4w)
        return torch.stack(outputs, dim=1)

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults: None.
            strict (bool, optional): Whether strictly load the pretrained
                model. Defaults to True.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is not None:
            raise TypeError(f'"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
class ResidualBlocksWithInputConv(nn.Module):
    """Residual blocks with a convolution in front.

    Args:
        in_channels (int): Number of input channels of the first conv.
        out_channels (int): Number of channels of the residual blocks.
            Default: 64.
        num_blocks (int): Number of residual blocks. Default: 30.
    """

    def __init__(self, in_channels, out_channels=64, num_blocks=30):
        super().__init__()

        main = []

        # a convolution used to match the channels of the residual blocks
        main.append(nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True))
        main.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))

        # residual blocks
        main.append(
            make_layer(
                ResidualBlockNoBN, num_blocks, mid_channels=out_channels))

        self.main = nn.Sequential(*main)

    def forward(self, feat):
        """Forward function for ResidualBlocksWithInputConv.

        Args:
            feat (Tensor): Input feature with shape (n, in_channels, h, w)

        Returns:
            Tensor: Output feature with shape (n, out_channels, h, w)
        """
        return self.main(feat)
class SPyNet(nn.Module):
    """SPyNet network structure.

    The difference to the SPyNet in [tof.py] is that
        1. more SPyNetBasicModule is used in this version, and
        2. no batch normalization is used in this version.

    Paper:
        Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017

    Args:
        pretrained (str): path for pre-trained SPyNet. Default: None.
    """

    def __init__(self, pretrained):
        super().__init__()

        # one basic module per pyramid level (6 levels)
        self.basic_module = nn.ModuleList(
            [SPyNetBasicModule() for _ in range(6)])

        # NOTE(review): the checkpoint is loaded with strict=True *before*
        # the `mean`/`std` buffers are registered, so those buffers are not
        # part of the module when the keys are matched. Presumably this is
        # intentional (checkpoint without these buffers) -- confirm before
        # reordering.
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=True, logger=logger)
        elif pretrained is not None:
            raise TypeError('[pretrained] should be str or None, '
                            f'but got {type(pretrained)}.')

        # per-channel statistics used to normalize the input images
        self.register_buffer(
            'mean',
            torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer(
            'std',
            torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def compute_flow(self, ref, supp):
        """Compute flow from ref to supp.

        Note that in this function, the images are already resized to a
        multiple of 32.

        Args:
            ref (Tensor): Reference image with shape of (n, 3, h, w).
            supp (Tensor): Supporting image with shape of (n, 3, h, w).

        Returns:
            Tensor: Estimated optical flow: (n, 2, h, w).
        """
        n, _, h, w = ref.size()

        # normalize the input images
        ref = [(ref - self.mean) / self.std]
        supp = [(supp - self.mean) / self.std]

        # generate downsampled frames: 5 successive 2x poolings give a
        # 6-level pyramid whose coarsest level is 1/32 of the input size
        for _ in range(5):
            ref.append(
                F.avg_pool2d(
                    input=ref[-1],
                    kernel_size=2,
                    stride=2,
                    count_include_pad=False))
            supp.append(
                F.avg_pool2d(
                    input=supp[-1],
                    kernel_size=2,
                    stride=2,
                    count_include_pad=False))
        # reorder from coarsest to finest
        ref = ref[::-1]
        supp = supp[::-1]

        # flow computation, starting from a zero flow at the coarsest level
        flow = ref[0].new_zeros(n, 2, h // 32, w // 32)
        for level in range(len(ref)):
            if level == 0:
                flow_up = flow
            else:
                # doubling the resolution also doubles the displacement
                # measured in pixels, hence the factor 2.0
                flow_up = F.interpolate(
                    input=flow,
                    scale_factor=2,
                    mode='bilinear',
                    align_corners=True) * 2.0

            # add the residue to the upsampled flow
            flow = flow_up + self.basic_module[level](
                torch.cat([
                    ref[level],
                    flow_warp(
                        supp[level],
                        flow_up.permute(0, 2, 3, 1),
                        padding_mode='border'), flow_up
                ], 1))

        return flow

    def forward(self, ref, supp):
        """Forward function of SPyNet.

        This function computes the optical flow from ref to supp.

        Args:
            ref (Tensor): Reference image with shape of (n, 3, h, w).
            supp (Tensor): Supporting image with shape of (n, 3, h, w).

        Returns:
            Tensor: Estimated optical flow: (n, 2, h, w).
        """

        # upsize to a multiple of 32
        h, w = ref.shape[2:4]
        w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
        h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
        ref = F.interpolate(
            input=ref, size=(h_up, w_up), mode='bilinear', align_corners=False)
        supp = F.interpolate(
            input=supp,
            size=(h_up, w_up),
            mode='bilinear',
            align_corners=False)

        # compute flow, and resize back to the original resolution
        flow = F.interpolate(
            input=self.compute_flow(ref, supp),
            size=(h, w),
            mode='bilinear',
            align_corners=False)

        # adjust the flow values: channel 0 is scaled by the width ratio and
        # channel 1 by the height ratio of original vs. upsized resolution
        flow[:, 0, :, :] *= float(w) / float(w_up)
        flow[:, 1, :, :] *= float(h) / float(h_up)

        return flow
class SPyNetBasicModule(nn.Module):
    """Basic Module for SPyNet.

    A stack of five 7x7 convolutions (8 -> 32 -> 64 -> 32 -> 16 -> 2
    channels); all but the last are followed by ReLU.

    Paper:
        Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
    """

    def __init__(self):
        super().__init__()

        self.basic_module = nn.Sequential(
            ConvModule(
                in_channels=8,
                out_channels=32,
                kernel_size=7,
                stride=1,
                padding=3,
                norm_cfg=None,
                act_cfg=dict(type='ReLU')),
            ConvModule(
                in_channels=32,
                out_channels=64,
                kernel_size=7,
                stride=1,
                padding=3,
                norm_cfg=None,
                act_cfg=dict(type='ReLU')),
            ConvModule(
                in_channels=64,
                out_channels=32,
                kernel_size=7,
                stride=1,
                padding=3,
                norm_cfg=None,
                act_cfg=dict(type='ReLU')),
            ConvModule(
                in_channels=32,
                out_channels=16,
                kernel_size=7,
                stride=1,
                padding=3,
                norm_cfg=None,
                act_cfg=dict(type='ReLU')),
            # final layer outputs the 2-channel flow, without activation
            ConvModule(
                in_channels=16,
                out_channels=2,
                kernel_size=7,
                stride=1,
                padding=3,
                norm_cfg=None,
                act_cfg=None))

    def forward(self, tensor_input):
        """Forward function of SPyNetBasicModule.

        Args:
            tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
                8 channels contain:
                [reference image (3), neighbor image (3), initial flow (2)].

        Returns:
            Tensor: Refined flow with shape (b, 2, h, w)
        """
        return self.basic_module(tensor_input)