EDVR for Super-Resolution: A Code Walkthrough

EDVR: Video Restoration with Enhanced Deformable Convolutional Networks

Paper: https://openaccess.thecvf.com/content_CVPRW_2019/papers/NTIRE/Wang_EDVR_Video_Restoration_With_Enhanced_Deformable_Convolutional_Networks_CVPRW_2019_paper.pdf

Code: https://github.com/XPixelGroup/BasicSR



In my recent research I have worked with registration (alignment) networks, and alignment modules appear in many video super-resolution tasks. I recently studied the EDVR paper. There are already many excellent write-ups online (such as the reference blog linked above), from which I learned a lot; my thanks to its author.


However, there does not yet seem to be a detailed walkthrough of the EDVR code, so I am recording mine here to share with everyone.


PS:
1. My own research direction is not video super-resolution, so some terminology may be imprecise; corrections are welcome.
2. For the theoretical background, see the reference blog mentioned above, which explains it very well. This post focuses on the code.
3. The full implementation is long, so not every detail can be covered here, but almost all of the code is included, with detailed comments throughout.


1. The EDVR Network Model

A VSR pipeline typically consists of four parts: feature extraction, alignment, fusion, and reconstruction.
(Figure: overview of the EDVR architecture)

The code below defines the top-level EDVR class. As you can see, it does follow the order of feature extraction, alignment, fusion, and reconstruction, along with a few extra details (such as the optional pre-deblur module). Each component of the network will be explained in detail.
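Before reading the full class, the forward pass can be summarized by this simplified skeleton (the stage order follows the paper; the stub functions here are illustrative placeholders, not the real modules):

```python
# Illustrative skeleton of the EDVR forward pass; each stub stands in
# for a real module defined in the code below.
def extract_features(frame):          # stub: identity
    return frame

def pcd_align(feat, ref):             # stub: identity alignment
    return feat

def tsa_fusion(aligned, center_idx):  # stub: simple average (the real TSA is attention-based)
    return sum(aligned) / len(aligned)

def reconstruct(feat):                # stub: identity
    return feat

def upsample(frame):                  # stub: identity (the real model upsamples x4)
    return frame

def edvr_forward(frames, center_idx):
    feats = [extract_features(f) for f in frames]       # 1. feature extraction
    ref = feats[center_idx]                             #    center frame is the reference
    aligned = [pcd_align(f, ref) for f in feats]        # 2. alignment (PCD)
    fused = tsa_fusion(aligned, center_idx)             # 3. fusion (TSA)
    return reconstruct(fused) + upsample(frames[center_idx])  # 4. reconstruction + residual base

print(edvr_forward([1.0, 2.0, 3.0, 4.0, 5.0], center_idx=2))  # 6.0
```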

import torch
from torch import nn as nn
from torch.nn import functional as F
from torchsummary import summary
#from basicsr.utils.registry import ARCH_REGISTRY
from basicsr.archs.arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer
#install the basicsr package via `pip install basicsr`


class EDVR(nn.Module):
    """EDVR network structure for video super-resolution.
    Now only support X4 upsampling factor.
    """
    def __init__(self,
                 num_in_ch=3,       #num_in_ch (int): Channel number of input image. Default: 3.
                 num_out_ch=3,      #num_out_ch (int): Channel number of output image. Default: 3.
                 num_feat=64,       #num_feat (int): Channel number of intermediate features. Default: 64.
                 num_frame=5,       #num_frame (int): Number of input frames. Default: 5.
                 deformable_groups=8,  #deformable_groups (int): Deformable groups. Defaults: 8.
                 num_extract_block=5,  #num_extract_block (int): Number of blocks for feature extraction.Default: 5.
                 num_reconstruct_block=10,  #num_reconstruct_block (int): Number of blocks for reconstruction.Default: 10.
                 center_frame_idx=None,     #center_frame_idx (int): The index of center frame. Frame counting from 0. Default: Middle of input frames.
                 hr_in=False,               #hr_in (bool): Whether the input has high resolution. Default: False.
                 with_predeblur=False,      #with_predeblur (bool): Whether has predeblur module.Default: False.
                 with_tsa=True):            #with_tsa (bool): Whether has TSA module. Default: True.
        super(EDVR, self).__init__()
        if center_frame_idx is None:
            self.center_frame_idx = num_frame // 2    #use the middle frame as the reference frame
        else:
            self.center_frame_idx = center_frame_idx
        self.hr_in = hr_in
        self.with_predeblur = with_predeblur
        self.with_tsa = with_tsa

        # extract features for each frame
        if self.with_predeblur:
            self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in)   #pre-deblur module
            self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
        else:
            self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)  #no pre-deblur: spatial size unchanged, channels become num_feat

        # extract pyramid features
        #extract features through num_extract_block (default 5) residual blocks
        #this corresponds to the L1 level of the pyramid
        self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
        #L2 level
        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        #L3 level
        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # pcd and tsa module
        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
        if self.with_tsa:
            self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx)
        else:
            self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)

        # reconstruction
        self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)
        # upsample
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        b, t, c, h, w = x.size()
        if self.hr_in:
            assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiple of 16.')
        else:
            assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiple of 4.')

        x_center = x[:, self.center_frame_idx, :, :, :].contiguous()    #extract the center (reference) frame

        # extract features for each frame
        # L1
        if self.with_predeblur:
            feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
            if self.hr_in:
                h, w = h // 4, w // 4
        else:
            feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))

        feat_l1 = self.feature_extraction(feat_l1)
        # L2
        feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
        feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
        # L3
        feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
        feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
        #each pyramid level is downsampled by 2x relative to the previous one
        feat_l1 = feat_l1.view(b, t, -1, h, w)
        feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
        feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)

        # PCD alignment
        ref_feat_l = [  # reference feature list: pyramid (L1, L2, L3) features of the reference frame; the order matters
            feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
            feat_l3[:, self.center_frame_idx, :, :, :].clone()
        ]
        aligned_feat = []  #list of aligned features
        for i in range(t):
            #for each frame i, collect its L1/L2/L3 features and feed them together with ref_feat_l into the alignment module
            nbr_feat_l = [  # neighboring feature list: pyramid features of frame i
                feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
            ]
            aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))  #after the loop, aligned features of all t frames are collected
        aligned_feat = torch.stack(aligned_feat, dim=1)  #stack the list into a tensor of shape (b, t, c, h, w)

        if not self.with_tsa:
            aligned_feat = aligned_feat.view(b, -1, h, w)
        feat = self.fusion(aligned_feat)

        out = self.reconstruction(feat)
        out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
        out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        out = self.lrelu(self.conv_hr(out))
        out = self.conv_last(out)
        if self.hr_in:
            base = x_center
        else:
            base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
        out += base
        return out
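The x4 upsampling tail at the end of `forward` deserves a closer look: each (conv -> `PixelShuffle(2)`) stage trades 4x channels for 2x spatial size, and the bilinearly upsampled center frame is added back as a residual base. A minimal, self-contained sketch (random weights, shapes only):

```python
import torch
from torch import nn
from torch.nn import functional as F

num_feat = 64
upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)  # 64 -> 256 channels
upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)        # 64 -> 256 channels
ps = nn.PixelShuffle(2)                               # (c*4, h, w) -> (c, 2h, 2w)

feat = torch.randn(1, num_feat, 32, 32)   # fused feature map (b, c, h, w)
x_center = torch.randn(1, 3, 32, 32)      # low-resolution center frame

out = ps(upconv1(feat))                   # (1, 64, 64, 64)
out = ps(upconv2(out))                    # (1, 64, 128, 128): 4x the input size
base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
print(out.shape, base.shape)              # both are spatially 128x128
```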

The complete network structure can be printed with:

net = EDVR()
print(net)
summary(net, [(5, 3, 128, 128)], device="cpu")

The output is as follows:

EDVR(
  (conv_first): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (feature_extraction): Sequential(
    (0): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (1): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (2): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (3): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (4): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
  )
  (conv_l2_1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv_l2_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_l3_1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv_l3_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pcd_align): PCDAlignment(
    (offset_conv1): ModuleDict(
      (l3): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (l2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (l1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (offset_conv2): ModuleDict(
      (l3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (l2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (l1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (offset_conv3): ModuleDict(
      (l2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (l1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (dcn_pack): ModuleDict(
      (l3): DCNv2Pack(
        (conv_offset): Conv2d(64, 216, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (l2): DCNv2Pack(
        (conv_offset): Conv2d(64, 216, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (l1): DCNv2Pack(
        (conv_offset): Conv2d(64, 216, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (feat_conv): ModuleDict(
      (l2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (l1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (cas_offset_conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (cas_offset_conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (cas_dcnpack): DCNv2Pack(
      (conv_offset): Conv2d(64, 216, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (upsample): Upsample(scale_factor=2.0, mode=bilinear)
    (lrelu): LeakyReLU(negative_slope=0.1, inplace=True)
  )
  (fusion): TSAFusion(
    (temporal_attn1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (temporal_attn2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (feat_fusion): Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1))
    (max_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (avg_pool): AvgPool2d(kernel_size=3, stride=2, padding=1)
    (spatial_attn1): Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1))
    (spatial_attn2): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
    (spatial_attn3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (spatial_attn4): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (spatial_attn5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (spatial_attn_l1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (spatial_attn_l2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (spatial_attn_l3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (spatial_attn_add1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (spatial_attn_add2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (lrelu): LeakyReLU(negative_slope=0.1, inplace=True)
    (upsample): Upsample(scale_factor=2.0, mode=bilinear)
  )
  (reconstruction): Sequential(
    (0): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (1): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (2): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (3): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (4): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (5): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (6): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (7): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (8): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (9): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
  )
  (upconv1): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upconv2): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pixel_shuffle): PixelShuffle(upscale_factor=2)
  (conv_hr): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_last): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (lrelu): LeakyReLU(negative_slope=0.1, inplace=True)
)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [-1, 64, 128, 128]        1,792
├─LeakyReLU: 1-2                         [-1, 64, 128, 128]        --
├─Sequential: 1-3                        [-1, 64, 128, 128]        --
|    └─ResidualBlockNoBN: 2-1            [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-1                  [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-2                    [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-3                  [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-2            [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-4                  [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-5                    [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-6                  [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-3            [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-7                  [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-8                    [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-9                  [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-4            [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-10                 [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-11                   [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-12                 [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-5            [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-13                 [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-14                   [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-15                 [-1, 64, 128, 128]        36,928
├─Conv2d: 1-4                            [-1, 64, 64, 64]          36,928
├─LeakyReLU: 1-5                         [-1, 64, 64, 64]          --
├─Conv2d: 1-6                            [-1, 64, 64, 64]          36,928
├─LeakyReLU: 1-7                         [-1, 64, 64, 64]          --
├─Conv2d: 1-8                            [-1, 64, 32, 32]          36,928
├─LeakyReLU: 1-9                         [-1, 64, 32, 32]          --
├─Conv2d: 1-10                           [-1, 64, 32, 32]          36,928
├─LeakyReLU: 1-11                        [-1, 64, 32, 32]          --
├─PCDAlignment: 1-12                     [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-16                 [-1, 64, 32, 32]          73,792
|    └─LeakyReLU: 2-6                    [-1, 64, 32, 32]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-17                 [-1, 64, 32, 32]          36,928
|    └─LeakyReLU: 2-7                    [-1, 64, 32, 32]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-18              [-1, 64, 32, 32]          161,560
|    └─LeakyReLU: 2-8                    [-1, 64, 32, 32]          --
|    └─Upsample: 2-9                     [-1, 64, 64, 64]          --
|    └─Upsample: 2-10                    [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-19                 [-1, 64, 64, 64]          73,792
|    └─LeakyReLU: 2-11                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-20                 [-1, 64, 64, 64]          73,792
|    └─LeakyReLU: 2-12                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-21                 [-1, 64, 64, 64]          36,928
|    └─LeakyReLU: 2-13                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-22              [-1, 64, 64, 64]          161,560
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-23                 [-1, 64, 64, 64]          73,792
|    └─LeakyReLU: 2-14                   [-1, 64, 64, 64]          --
|    └─Upsample: 2-15                    [-1, 64, 128, 128]        --
|    └─Upsample: 2-16                    [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-24                 [-1, 64, 128, 128]        73,792
|    └─LeakyReLU: 2-17                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-25                 [-1, 64, 128, 128]        73,792
|    └─LeakyReLU: 2-18                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-26                 [-1, 64, 128, 128]        36,928
|    └─LeakyReLU: 2-19                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-27              [-1, 64, 128, 128]        161,560
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-28                 [-1, 64, 128, 128]        73,792
|    └─Conv2d: 2-20                      [-1, 64, 128, 128]        73,792
|    └─LeakyReLU: 2-21                   [-1, 64, 128, 128]        --
|    └─Conv2d: 2-22                      [-1, 64, 128, 128]        36,928
|    └─LeakyReLU: 2-23                   [-1, 64, 128, 128]        --
|    └─DCNv2Pack: 2-24                   [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-29                 [-1, 216, 128, 128]       124,632
|    └─LeakyReLU: 2-25                   [-1, 64, 128, 128]        --
├─PCDAlignment: 1-13                     [-1, 64, 128, 128]        (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-30                 [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-26                   [-1, 64, 32, 32]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-31                 [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-27                   [-1, 64, 32, 32]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-32              [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-28                   [-1, 64, 32, 32]          --
|    └─Upsample: 2-29                    [-1, 64, 64, 64]          --
|    └─Upsample: 2-30                    [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-33                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-31                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-34                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-32                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-35                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-33                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-36              [-1, 64, 64, 64]          (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-37                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-34                   [-1, 64, 64, 64]          --
|    └─Upsample: 2-35                    [-1, 64, 128, 128]        --
|    └─Upsample: 2-36                    [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-38                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-37                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-39                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-38                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-40                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-39                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-41              [-1, 64, 128, 128]        (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-42                 [-1, 64, 128, 128]        (recursive)
|    └─Conv2d: 2-40                      [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-41                   [-1, 64, 128, 128]        --
|    └─Conv2d: 2-42                      [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-43                   [-1, 64, 128, 128]        --
|    └─DCNv2Pack: 2-44                   [-1, 64, 128, 128]        (recursive)
|    |    └─Conv2d: 3-43                 [-1, 216, 128, 128]       (recursive)
|    └─LeakyReLU: 2-45                   [-1, 64, 128, 128]        --
├─PCDAlignment: 1-14                     [-1, 64, 128, 128]        (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-44                 [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-46                   [-1, 64, 32, 32]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-45                 [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-47                   [-1, 64, 32, 32]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-46              [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-48                   [-1, 64, 32, 32]          --
|    └─Upsample: 2-49                    [-1, 64, 64, 64]          --
|    └─Upsample: 2-50                    [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-47                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-51                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-48                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-52                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-49                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-53                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-50              [-1, 64, 64, 64]          (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-51                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-54                   [-1, 64, 64, 64]          --
|    └─Upsample: 2-55                    [-1, 64, 128, 128]        --
|    └─Upsample: 2-56                    [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-52                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-57                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-53                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-58                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-54                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-59                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-55              [-1, 64, 128, 128]        (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-56                 [-1, 64, 128, 128]        (recursive)
|    └─Conv2d: 2-60                      [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-61                   [-1, 64, 128, 128]        --
|    └─Conv2d: 2-62                      [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-63                   [-1, 64, 128, 128]        --
|    └─DCNv2Pack: 2-64                   [-1, 64, 128, 128]        (recursive)
|    |    └─Conv2d: 3-57                 [-1, 216, 128, 128]       (recursive)
|    └─LeakyReLU: 2-65                   [-1, 64, 128, 128]        --
├─PCDAlignment: 1-15                     [-1, 64, 128, 128]        (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-58                 [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-66                   [-1, 64, 32, 32]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-59                 [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-67                   [-1, 64, 32, 32]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-60              [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-68                   [-1, 64, 32, 32]          --
|    └─Upsample: 2-69                    [-1, 64, 64, 64]          --
|    └─Upsample: 2-70                    [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-61                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-71                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-62                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-72                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-63                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-73                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-64              [-1, 64, 64, 64]          (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-65                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-74                   [-1, 64, 64, 64]          --
|    └─Upsample: 2-75                    [-1, 64, 128, 128]        --
|    └─Upsample: 2-76                    [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-66                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-77                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-67                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-78                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-68                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-79                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-69              [-1, 64, 128, 128]        (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-70                 [-1, 64, 128, 128]        (recursive)
|    └─Conv2d: 2-80                      [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-81                   [-1, 64, 128, 128]        --
|    └─Conv2d: 2-82                      [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-83                   [-1, 64, 128, 128]        --
|    └─DCNv2Pack: 2-84                   [-1, 64, 128, 128]        (recursive)
|    |    └─Conv2d: 3-71                 [-1, 216, 128, 128]       (recursive)
|    └─LeakyReLU: 2-85                   [-1, 64, 128, 128]        --
├─PCDAlignment: 1-16                     [-1, 64, 128, 128]        (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-72                 [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-86                   [-1, 64, 32, 32]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-73                 [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-87                   [-1, 64, 32, 32]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-74              [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-88                   [-1, 64, 32, 32]          --
|    └─Upsample: 2-89                    [-1, 64, 64, 64]          --
|    └─Upsample: 2-90                    [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-75                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-91                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-76                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-92                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-77                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-93                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-78              [-1, 64, 64, 64]          (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-79                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-94                   [-1, 64, 64, 64]          --
|    └─Upsample: 2-95                    [-1, 64, 128, 128]        --
|    └─Upsample: 2-96                    [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-80                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-97                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-81                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-98                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-82                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-99                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-83              [-1, 64, 128, 128]        (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-84                 [-1, 64, 128, 128]        (recursive)
|    └─Conv2d: 2-100                     [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-101                  [-1, 64, 128, 128]        --
|    └─Conv2d: 2-102                     [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-103                  [-1, 64, 128, 128]        --
|    └─DCNv2Pack: 2-104                  [-1, 64, 128, 128]        (recursive)
|    |    └─Conv2d: 3-85                 [-1, 216, 128, 128]       (recursive)
|    └─LeakyReLU: 2-105                  [-1, 64, 128, 128]        --
├─TSAFusion: 1-17                        [-1, 64, 128, 128]        --
|    └─Conv2d: 2-106                     [-1, 64, 128, 128]        36,928
|    └─Conv2d: 2-107                     [-1, 64, 128, 128]        36,928
|    └─Conv2d: 2-108                     [-1, 64, 128, 128]        20,544
|    └─LeakyReLU: 2-109                  [-1, 64, 128, 128]        --
|    └─Conv2d: 2-110                     [-1, 64, 128, 128]        20,544
|    └─LeakyReLU: 2-111                  [-1, 64, 128, 128]        --
|    └─MaxPool2d: 2-112                  [-1, 64, 64, 64]          --
|    └─AvgPool2d: 2-113                  [-1, 64, 64, 64]          --
|    └─Conv2d: 2-114                     [-1, 64, 64, 64]          8,256
|    └─LeakyReLU: 2-115                  [-1, 64, 64, 64]          --
|    └─Conv2d: 2-116                     [-1, 64, 64, 64]          4,160
|    └─LeakyReLU: 2-117                  [-1, 64, 64, 64]          --
|    └─MaxPool2d: 2-118                  [-1, 64, 32, 32]          --
|    └─AvgPool2d: 2-119                  [-1, 64, 32, 32]          --
|    └─Conv2d: 2-120                     [-1, 64, 32, 32]          73,792
|    └─LeakyReLU: 2-121                  [-1, 64, 32, 32]          --
|    └─Conv2d: 2-122                     [-1, 64, 32, 32]          36,928
|    └─LeakyReLU: 2-123                  [-1, 64, 32, 32]          --
|    └─Upsample: 2-124                   [-1, 64, 64, 64]          --
|    └─Conv2d: 2-125                     [-1, 64, 64, 64]          36,928
|    └─LeakyReLU: 2-126                  [-1, 64, 64, 64]          --
|    └─Conv2d: 2-127                     [-1, 64, 64, 64]          4,160
|    └─LeakyReLU: 2-128                  [-1, 64, 64, 64]          --
|    └─Upsample: 2-129                   [-1, 64, 128, 128]        --
|    └─Conv2d: 2-130                     [-1, 64, 128, 128]        36,928
|    └─Conv2d: 2-131                     [-1, 64, 128, 128]        4,160
|    └─LeakyReLU: 2-132                  [-1, 64, 128, 128]        --
|    └─Conv2d: 2-133                     [-1, 64, 128, 128]        4,160
├─Sequential: 1-18                       [-1, 64, 128, 128]        --
|    └─ResidualBlockNoBN: 2-134          [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-86                 [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-87                   [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-88                 [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-135          [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-89                 [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-90                   [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-91                 [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-136          [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-92                 [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-93                   [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-94                 [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-137          [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-95                 [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-96                   [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-97                 [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-138          [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-98                 [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-99                   [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-100                [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-139          [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-101                [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-102                  [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-103                [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-140          [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-104                [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-105                  [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-106                [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-141          [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-107                [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-108                  [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-109                [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-142          [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-110                [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-111                  [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-112                [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-143          [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-113                [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-114                  [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-115                [-1, 64, 128, 128]        36,928
├─Conv2d: 1-19                           [-1, 256, 128, 128]       147,712
├─PixelShuffle: 1-20                     [-1, 64, 256, 256]        --
├─LeakyReLU: 1-21                        [-1, 64, 256, 256]        --
├─Conv2d: 1-22                           [-1, 256, 256, 256]       147,712
├─PixelShuffle: 1-23                     [-1, 64, 512, 512]        --
├─LeakyReLU: 1-24                        [-1, 64, 512, 512]        --
├─Conv2d: 1-25                           [-1, 64, 512, 512]        36,928
├─LeakyReLU: 1-26                        [-1, 64, 512, 512]        --
├─Conv2d: 1-27                           [-1, 3, 512, 512]         1,731
==========================================================================================
Total params: 3,263,203
Trainable params: 3,263,203
Non-trainable params: 0
Total mult-adds (G): 103.30
==========================================================================================
Input size (MB): 0.94
Forward/backward pass size (MB): 731.44
Params size (MB): 12.45
Estimated Total Size (MB): 744.82
==========================================================================================

This section (Section 1) walks through the code of the top-level EDVR class. Since the code is quite long, the walkthrough is organized as follows:
1. For the top-level EDVR class, at the start of each code snippet I note whether it belongs to the constructor (`__init__`) or to the `forward` function.
2. The top-level EDVR class calls many sub-module classes, and those sub-modules may in turn use even smaller helper classes; these are covered in Section 2.


1.1 Input

The input is a group of frames (5 by default).
The code below applies a single convolution to the input images to obtain feature maps: the spatial shape is unchanged, while the channel count goes from 3 to the configured num_feat (the authors use 64). This can be seen as extracting the initial features for level L1 of the pyramid.

####__init__ constructor
        # extract features for each frame
        if self.with_predeblur:
            self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in)   # deblurring module
            self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
        else:
            self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)  # no deblurring: shape unchanged, channels become num_feat

The network structure of this part is as follows:

(conv_first): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [-1, 64, 128, 128]        1,792
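As a quick sanity check, the first convolution can be reproduced in isolation (a minimal sketch, independent of BasicSR): the 128×128 spatial size is preserved while 3 channels expand to 64, and the parameter count matches the 1,792 in the summary (3·64·3·3 weights + 64 biases).

```python
import torch
from torch import nn

# Same configuration as conv_first: 3 -> 64 channels, 3x3 kernel, stride 1, padding 1
conv_first = nn.Conv2d(3, 64, 3, 1, 1)

x = torch.randn(1, 3, 128, 128)    # one RGB frame
feat = conv_first(x)
print(feat.shape)                  # torch.Size([1, 64, 128, 128])
print(sum(p.numel() for p in conv_first.parameters()))  # 1792
```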

1.2 Feature extraction

The next step is to build the pyramid (L1, L2, L3); that is, feature extraction produces three levels.
Note that L1, L2, L3 go from the bottom of the pyramid to the top, not from the top to the bottom.
(figure: the three-level feature pyramid)
The following code extracts the three levels of features:

####__init__ constructor
 # extract pyramid features
        # features are extracted by 5 residual blocks
        # this is the L1 level
        self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
        # L2 level
        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # L3 level
        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
 (feature_extraction): Sequential(
    (0): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (1): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (2): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (3): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (4): ResidualBlockNoBN(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
  )
  (conv_l2_1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv_l2_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_l3_1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv_l3_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
├─Sequential: 1-3                        [-1, 64, 128, 128]        --
|    └─ResidualBlockNoBN: 2-1            [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-1                  [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-2                    [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-3                  [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-2            [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-4                  [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-5                    [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-6                  [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-3            [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-7                  [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-8                    [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-9                  [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-4            [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-10                 [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-11                   [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-12                 [-1, 64, 128, 128]        36,928
|    └─ResidualBlockNoBN: 2-5            [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-13                 [-1, 64, 128, 128]        36,928
|    |    └─ReLU: 3-14                   [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-15                 [-1, 64, 128, 128]        36,928
├─Conv2d: 1-4                            [-1, 64, 64, 64]          36,928
├─LeakyReLU: 1-5                         [-1, 64, 64, 64]          --
├─Conv2d: 1-6                            [-1, 64, 64, 64]          36,928
├─LeakyReLU: 1-7                         [-1, 64, 64, 64]          --
├─Conv2d: 1-8                            [-1, 64, 32, 32]          36,928
├─LeakyReLU: 1-9                         [-1, 64, 32, 32]          --
├─Conv2d: 1-10                           [-1, 64, 32, 32]          36,928
├─LeakyReLU: 1-11                        [-1, 64, 32, 32]          --

As shown above, the pyramid extracts multi-scale features: the channel count stays the same while the spatial size is halved at each level (128 -> 64 -> 32).
Also, as mentioned earlier,

 self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)

stacks multiple residual blocks; such a stack is straightforward, so I will not go into detail.
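To make the stacking concrete, here is a simplified sketch of `ResidualBlockNoBN` and `make_layer`; BasicSR's real versions additionally do weight initialization and support a residual scale, but the structure matches the summary above (two 3x3 convs, a ReLU, an identity skip, no BatchNorm).

```python
import torch
from torch import nn

class ResidualBlockNoBN(nn.Module):
    """Residual block without BatchNorm (simplified for illustration)."""
    def __init__(self, num_feat=64):
        super().__init__()
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # identity skip connection; shape and channels are unchanged
        return x + self.conv2(self.relu(self.conv1(x)))

def make_layer(block, num_blocks, **kwargs):
    """Stack num_blocks copies of block into one nn.Sequential."""
    return nn.Sequential(*[block(**kwargs) for _ in range(num_blocks)])

feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=64)
x = torch.randn(1, 64, 128, 128)
print(feature_extraction(x).shape)  # torch.Size([1, 64, 128, 128])
```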

1.3 Feature alignment

Next comes the alignment step; the code is as follows:

####__init__ constructor

# pcd and tsa module
        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
####forward function

# PCD alignment
        ref_feat_l = [  # reference feature list: pyramid (L1, L2, L3) features of the reference frame. Mind the order!
            feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
            feat_l3[:, self.center_frame_idx, :, :, :].clone()
        ]
        aligned_feat = []  # list of aligned features
        for i in range(t):
            # for each i (each neighboring frame), collect its L1, L2, L3 features and feed them
            # into the alignment module together with ref_feat_l to align the features
            nbr_feat_l = [  # neighboring feature list: pyramid features of frame i
                feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
            ]
            aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))  # after the loop, holds the aligned features of all 5 frames
        aligned_feat = torch.stack(aligned_feat, dim=1)  # list -> tensor, (b, t, c, h, w)

The network structure of this part is as follows:

  (pcd_align): PCDAlignment(
    (offset_conv1): ModuleDict(
      (l3): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (l2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (l1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (offset_conv2): ModuleDict(
      (l3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (l2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (l1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (offset_conv3): ModuleDict(
      (l2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (l1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (dcn_pack): ModuleDict(
      (l3): DCNv2Pack(
        (conv_offset): Conv2d(64, 216, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (l2): DCNv2Pack(
        (conv_offset): Conv2d(64, 216, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (l1): DCNv2Pack(
        (conv_offset): Conv2d(64, 216, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (feat_conv): ModuleDict(
      (l2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (l1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (cas_offset_conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (cas_offset_conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (cas_dcnpack): DCNv2Pack(
      (conv_offset): Conv2d(64, 216, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (upsample): Upsample(scale_factor=2.0, mode=bilinear)
    (lrelu): LeakyReLU(negative_slope=0.1, inplace=True)
  )
├─PCDAlignment: 1-12                     [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-16                 [-1, 64, 32, 32]          73,792
|    └─LeakyReLU: 2-6                    [-1, 64, 32, 32]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-17                 [-1, 64, 32, 32]          36,928
|    └─LeakyReLU: 2-7                    [-1, 64, 32, 32]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-18              [-1, 64, 32, 32]          161,560
|    └─LeakyReLU: 2-8                    [-1, 64, 32, 32]          --
|    └─Upsample: 2-9                     [-1, 64, 64, 64]          --
|    └─Upsample: 2-10                    [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-19                 [-1, 64, 64, 64]          73,792
|    └─LeakyReLU: 2-11                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-20                 [-1, 64, 64, 64]          73,792
|    └─LeakyReLU: 2-12                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-21                 [-1, 64, 64, 64]          36,928
|    └─LeakyReLU: 2-13                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-22              [-1, 64, 64, 64]          161,560
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-23                 [-1, 64, 64, 64]          73,792
|    └─LeakyReLU: 2-14                   [-1, 64, 64, 64]          --
|    └─Upsample: 2-15                    [-1, 64, 128, 128]        --
|    └─Upsample: 2-16                    [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-24                 [-1, 64, 128, 128]        73,792
|    └─LeakyReLU: 2-17                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-25                 [-1, 64, 128, 128]        73,792
|    └─LeakyReLU: 2-18                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-26                 [-1, 64, 128, 128]        36,928
|    └─LeakyReLU: 2-19                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-27              [-1, 64, 128, 128]        161,560
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-28                 [-1, 64, 128, 128]        73,792
|    └─Conv2d: 2-20                      [-1, 64, 128, 128]        73,792
|    └─LeakyReLU: 2-21                   [-1, 64, 128, 128]        --
|    └─Conv2d: 2-22                      [-1, 64, 128, 128]        36,928
|    └─LeakyReLU: 2-23                   [-1, 64, 128, 128]        --
|    └─DCNv2Pack: 2-24                   [-1, 64, 128, 128]        --
|    |    └─Conv2d: 3-29                 [-1, 216, 128, 128]       124,632
|    └─LeakyReLU: 2-25                   [-1, 64, 128, 128]        --
├─PCDAlignment: 1-13                     [-1, 64, 128, 128]        (recursive)
## (omitted, identical structure)
├─PCDAlignment: 1-14                     [-1, 64, 128, 128]        (recursive)
## (omitted, identical structure)
├─PCDAlignment: 1-15                     [-1, 64, 128, 128]        (recursive)
## (omitted, identical structure)
├─PCDAlignment: 1-16                     [-1, 64, 128, 128]        (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-72                 [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-86                   [-1, 64, 32, 32]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-73                 [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-87                   [-1, 64, 32, 32]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-74              [-1, 64, 32, 32]          (recursive)
|    └─LeakyReLU: 2-88                   [-1, 64, 32, 32]          --
|    └─Upsample: 2-89                    [-1, 64, 64, 64]          --
|    └─Upsample: 2-90                    [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-75                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-91                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-76                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-92                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-77                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-93                   [-1, 64, 64, 64]          --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-78              [-1, 64, 64, 64]          (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-79                 [-1, 64, 64, 64]          (recursive)
|    └─LeakyReLU: 2-94                   [-1, 64, 64, 64]          --
|    └─Upsample: 2-95                    [-1, 64, 128, 128]        --
|    └─Upsample: 2-96                    [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-80                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-97                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-81                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-98                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-82                 [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-99                   [-1, 64, 128, 128]        --
|    └─ModuleDict: 2                     []                        --
|    |    └─DCNv2Pack: 3-83              [-1, 64, 128, 128]        (recursive)
|    └─ModuleDict: 2                     []                        --
|    |    └─Conv2d: 3-84                 [-1, 64, 128, 128]        (recursive)
|    └─Conv2d: 2-100                     [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-101                  [-1, 64, 128, 128]        --
|    └─Conv2d: 2-102                     [-1, 64, 128, 128]        (recursive)
|    └─LeakyReLU: 2-103                  [-1, 64, 128, 128]        --
|    └─DCNv2Pack: 2-104                  [-1, 64, 128, 128]        (recursive)
|    |    └─Conv2d: 3-85                 [-1, 216, 128, 128]       (recursive)
|    └─LeakyReLU: 2-105                  [-1, 64, 128, 128]        --

As you can see, PCDAlignment is invoked 5 times in total, once per frame. Note that it is the same shared module applied 5 times, which is why torchsummary marks the later passes as `(recursive)`.
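Shape-wise, the alignment loop can be sketched as follows (a minimal illustration only; `pcd_align` is replaced by an identity placeholder standing in for the real module, which returns one aligned `(b, c, h, w)` feature per call):

```python
import torch

b, t, c, h, w = 1, 5, 64, 128, 128
feat_l1 = torch.randn(b, t, c, h, w)   # L1 features of all 5 frames

# Placeholder for self.pcd_align: just returns the neighboring L1 feature unchanged
pcd_align = lambda nbr_feat_l, ref_feat_l: nbr_feat_l[0]

ref_feat_l = [feat_l1[:, 2].clone()]   # center frame (index 2 of 5) as reference
aligned = [pcd_align([feat_l1[:, i].clone()], ref_feat_l) for i in range(t)]
aligned = torch.stack(aligned, dim=1)  # (b, t, c, h, w)
print(aligned.shape)                   # torch.Size([1, 5, 64, 128, 128])
```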


1.4 Feature fusion

Next comes the fusion step (the TSAFusion module in the summary above):
To be filled in.


1.5 High-resolution reconstruction

Next comes the reconstruction step (residual blocks followed by PixelShuffle upsampling):
To be filled in.



2. Code of the individual sub-modules

Here we only cover the main modules (alignment, fusion, ...). Other smaller modules (such as the stack of 5 residual blocks mentioned above) are straightforward and are not covered in detail.

2.1 Alignment module: PCDAlignment

First, here is the full code of the PCDAlignment class:

class PCDAlignment(nn.Module):
    """Alignment module using Pyramid, Cascading and Deformable convolution
    (PCD). It is used in EDVR.

    ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``

    Args:
        num_feat (int): Channel number of middle features. Default: 64.
        deformable_groups (int): Deformable groups. Defaults: 8.
    """

    def __init__(self, num_feat=64, deformable_groups=8):
        super(PCDAlignment, self).__init__()

        # Pyramid has three levels:
        # L3: level 3, 1/4 spatial size
        # L2: level 2, 1/2 spatial size
        # L1: level 1, original spatial size
        self.offset_conv1 = nn.ModuleDict()
        self.offset_conv2 = nn.ModuleDict()
        self.offset_conv3 = nn.ModuleDict()
        self.dcn_pack = nn.ModuleDict()
        self.feat_conv = nn.ModuleDict()

        # Pyramids
        for i in range(3, 0, -1):   # iterate backwards, from the top of the pyramid down: 3, 2, 1
            level = f'l{i}'         # l3, l2, l1
            self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)  # two frames are concatenated to compute the offset, so the input has 2 * num_feat channels
            if i == 3:  # no upper level to merge, so input and output channels are both num_feat
                self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            else:  # otherwise, first one conv halves the channels back to num_feat, then one more conv keeps num_feat -> num_feat
                self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
                self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            # "Final offset" here only means this is the last step before the offset is computed.
            # The naming is a bit misleading: this is really a feature (built from the reference and
            # neighboring frame features). It is fed into the DCN together with the neighboring
            # frame feature to be aligned; the actual offset is computed inside DCNv2Pack.
            self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)

            if i < 3:
                self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)

        # Cascading dcn
        # The final aligned feature is concatenated with the L1 reference feature, doubling the channels;
        # cas_offset_conv1 brings the channel count back to num_feat.
        self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
        # cas_offset_conv2 keeps the channel count unchanged
        self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # compute the final aligned feature
        self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)

        # features and offsets at L3 and L2 are upsampled by 2x so they can be concatenated with the level below
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_feat_l, ref_feat_l):
        """Align neighboring frame features to the reference frame features.

        Args:
            nbr_feat_l (list[Tensor]): Neighboring feature list. It
                contains three pyramid levels (L1, L2, L3),
                each with shape (b, c, h, w).
            ref_feat_l (list[Tensor]): Reference feature list. It
                contains three pyramid levels (L1, L2, L3),
                each with shape (b, c, h, w).

        Returns:
            Tensor: Aligned features.
        """
        # Pyramids
        upsampled_offset, upsampled_feat = None, None
        for i in range(3, 0, -1):  # L3, L2, L1
            level = f'l{i}'
            offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1)  # concatenate the reference and neighboring features of level i along the channel dim
            offset = self.lrelu(self.offset_conv1[level](offset))   # reduce the concatenated 2 * num_feat channels back to num_feat
            if i == 3:    # at the top level, compute the offset directly (the other levels also merge the offset and feature from the level above)
                offset = self.lrelu(self.offset_conv2[level](offset))  # shape and channels unchanged
            else:         # e.g. at level 2: upsample the offset from the level above (level 3) and concatenate it with this level's offset
                offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1)))
                offset = self.lrelu(self.offset_conv3[level](offset))   # num_feat channels

            feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset)  # feed this level's final offset-feature and the neighboring frame to be aligned into the DCN
            if i < 3:
                # likewise, at L2 and L1 the aligned feature from the DCN is concatenated with the
                # (upsampled) aligned feature from the level above; the channels double, and feat_conv
                # brings them back to num_feat, yielding this level's final feature
                feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1))
            if i > 1:
                feat = self.lrelu(feat)

            if i > 1:  # upsample offset and features using PyTorch's bilinear interpolation
                # x2: when we upsample the offset, we should also enlarge the magnitude.
                upsampled_offset = self.upsample(offset) * 2
                upsampled_feat = self.upsample(feat)

        # Cascading
        # concatenate the final aligned feature with the L1 reference feature
        offset = torch.cat([feat, ref_feat_l[0]], dim=1)
        offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset))))
        feat = self.lrelu(self.cas_dcnpack(feat, offset))
        return feat

A few points worth noting:

  1. In the structure of the alignment network above there is the following block; note that its channel count differs from everywhere else:
    (dcn_pack): ModuleDict(
      (l3): DCNv2Pack(
        (conv_offset): Conv2d(64, 216, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (l2): DCNv2Pack(
        (conv_offset): Conv2d(64, 216, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (l1): DCNv2Pack(
        (conv_offset): Conv2d(64, 216, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )

In conv_offset the output channel count is 216. Where does 216 come from? 216 = 2×9×8 + 9×8 = 144 + 72. If you are familiar with DCN you will recognize this as the combined channel count of the offsets and the mask: 2 offset components (x and y) per sampling point, 3×3 = 9 sampling points per kernel, and 8 deformable groups, plus 9×8 mask channels.
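The arithmetic can be written out explicitly (a tiny sketch; the variable names here are mine, not BasicSR's):

```python
kernel_size = 3
deformable_groups = 8
sampling_points = kernel_size * kernel_size                # 9 sampling points per kernel

offset_channels = 2 * sampling_points * deformable_groups  # x and y offsets: 144
mask_channels = sampling_points * deformable_groups        # modulation mask:  72

print(offset_channels + mask_channels)  # 216, the out_channels of conv_offset
```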

2. In the alignment network code there is this snippet:

            self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)  # two frames are concatenated to compute the offset, so the input has 2 * num_feat channels
            if i == 3:  # no upper level to merge, so input and output channels are both num_feat
                self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            else:  # otherwise, first one conv halves the channels back to num_feat, then one more conv keeps num_feat -> num_feat
                self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
                self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            offset = self.lrelu(self.offset_conv1[level](offset))   # reduce the concatenated 2 * num_feat channels back to num_feat
            if i == 3:    # at the top level, compute the offset directly (the other levels also merge the offset and feature from the level above)
                offset = self.lrelu(self.offset_conv2[level](offset))  # shape and channels unchanged
            else:         # e.g. at level 2: upsample the offset from the level above (level 3) and concatenate it with this level's offset
                offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1)))
                offset = self.lrelu(self.offset_conv3[level](offset))   # num_feat channels

The offset_conv1/2/3 layers here are different from the conv_offset in point 1 above. offset_conv1/2/3 are ordinary convolutions; although their names contain "offset", they do not compute the actual offsets. Why the name, then? My understanding: they compute the offset *feature* (the reference and neighboring frame features are concatenated and convolved), which is then carried into the DCN to compute the real offsets; hence the "offset" in the variable names.
That is the following code:

            self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
            feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset)  # feed this level's final offset-feature and the neighboring frame to be aligned into the DCN

3. Now we get to the key class, DCNv2Pack. One class (PCDAlignment) wraps another (DCNv2Pack), so there is quite a bit of nesting to work through.
Here is the code of DCNv2Pack:

class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.

    Inherits from the parent class ModulatedDeformConvPack.

    Different from the official DCNv2Pack, which generates offsets and masks
    from the preceding features, this DCNv2Pack takes another different
    features to generate offsets and masks.

    Ref:
        Delving Deep into Deformable Alignment in Video Super-Resolution.
    """

    # feed the offset feature (feat) and the neighboring frame to be aligned into the DCN
    def forward(self, x, feat):
        out = self.conv_offset(feat)  # conv_offset is defined in ModulatedDeformConvPack; it can be called here thanks to inheritance
        # compute the offsets and the mask
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)

        # DCN can fail to converge stably during training, which in turn destabilizes the whole network.
        # The author acknowledged this problem; it is a training weakness of this model.
        # In experiments, unstable training prints "Offset mean is larger than 100" (the code below uses 50).

        offset_absmean = torch.mean(torch.abs(offset))   # mean absolute offset, a proxy for the deformation magnitude; if it is very large, training has usually diverged
        if offset_absmean > 50:
            logger = get_root_logger()
            logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
            # can be replaced directly with the line below
            # print(f'Offset abs mean is {offset_absmean}, larger than 50.')

        # check the torchvision version to choose which DCN implementation to use
        # my torchvision version is 0.11.3
        if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
            # use the official PyTorch API
            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                                 self.dilation, mask)
        else:
            # use the API implemented by the BasicSR library
            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                         self.dilation, self.groups, self.deformable_groups)
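To make the `torch.chunk` / `torch.cat` bookkeeping in `forward` concrete: `conv_offset` emits `deformable_groups * 3 * kh * kw` channels, `torch.chunk(out, 3, dim=1)` cuts the channel axis into three equal parts, and the first two thirds are glued back together as the offsets while the last third becomes the mask. A pure-Python sketch of the same split (my own illustration, operating on channel counts rather than tensors):

```python
def split_offset_mask(num_channels):
    """Return (offset channel count, mask channel count) the way
    torch.chunk(out, 3, dim=1) followed by torch.cat((o1, o2), dim=1) would."""
    third = num_channels // 3
    return 2 * third, third

# deformable_groups=8 with a 3x3 kernel -> 3 * 8 * 3 * 3 = 216 channels
offset_ch, mask_ch = split_offset_mask(3 * 8 * 3 * 3)
print(offset_ch, mask_ch)  # 144 72  (2 offset coords + 1 mask per sampling point per group)
```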

See (in the code above)? This class (DCNv2Pack) in turn wraps yet another class (ModulatedDeformConvPack). I'm losing it here. Let's lay that class's code out as well:

class ModulatedDeformConvPack(ModulatedDeformConv):
    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
    Why put it that way? Looking at the code below, I take it to mean that the offsets and masks are computed with an ordinary (normal) convolution.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
    """

    _version = 2   # presumably refers to DCN v2

    def __init__(self, *args, **kwargs):
        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)

        # This is the "modulated deformable conv encapsulation that acts as a normal conv layer" mentioned above; in forward it is used to compute the offsets and the mask.
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            # 2*8*3*3 + 1*8*3*3
            self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)
        self.init_weights()

    def init_weights(self):  # weight initialization
        super(ModulatedDeformConvPack, self).init_weights()
        # hasattr(object, name) checks whether the object has an attribute called name
        # (for Python objects, attributes include both variables and methods);
        # it returns True if the attribute exists and False otherwise.
        # Note that name is a string, so pass the variable or method name as a string;
        # the same applies to getattr and setattr.
        # https://zhuanlan.zhihu.com/p/411525142
        if hasattr(self, 'conv_offset'):
            # initialize both to zero
            self.conv_offset.weight.data.zero_()
            self.conv_offset.bias.data.zero_()

    def forward(self, x):
        out = self.conv_offset(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)
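One detail worth pausing on is `init_weights`: with `conv_offset`'s weights and bias both zeroed, every predicted offset starts at 0 (the DCN samples exactly on the regular grid) and every mask value starts at sigmoid(0) = 0.5, so training begins from an ordinary, uniformly-modulated convolution and the deformation is learned gradually. A quick check (my own snippet):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# zero-initialized conv_offset -> all offsets are 0, all mask values are 0.5
print(sigmoid(0.0))  # 0.5
```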

So, it is classes inside classes again. (Note that I deliberately say "wrapping" rather than "inheritance". At first that felt unprofessional, but on reflection "inheritance" would actually be wrong: in the nesting above, some of the relationships are inheritance and some are plain calls.)
Let's analyze the two classes above (DCNv2Pack and ModulatedDeformConvPack).
(1) First, ModulatedDeformConvPack. Its core splits into two parts:
① Computing the offsets and the mask: the first four statements of forward.
② Performing the DCN: the final `return modulated_deform_conv`. This class (ModulatedDeformConvPack) does not define modulated_deform_conv itself; the reason is obvious by now, we have to go up one more level /(ㄒoㄒ)/~~. That is, modulated_deform_conv is provided alongside ModulatedDeformConv, the class that ModulatedDeformConvPack inherits from.
(2) Now for the DCNv2Pack class:
Notice that DCNv2Pack only defines a forward method; there is no constructor.

    def forward(self, x, feat):
        out = self.conv_offset(feat)  # conv_offset is defined in ModulatedDeformConvPack; it can be called here thanks to inheritance
        # compute the offsets and the mask
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        ##……………………………………
        ##……………………………………

① The first statement of forward: conv_offset lives in ModulatedDeformConvPack; thanks to inheritance, this calls the layer defined in the parent class. So when DCNv2Pack is instantiated,

self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)

the arguments are passed straight through to ModulatedDeformConvPack.
And this conv_offset is exactly point 1 discussed earlier (the channel question).
② The second part of forward: DCN can fail to converge stably during training, which makes the whole network hard to train. The author acknowledged this problem; it is a training weakness of this model, and also a direction for improvement.

        # In experiments, unstable training prints "Offset mean is larger than 100" (the code below uses 50).

        offset_absmean = torch.mean(torch.abs(offset))   # mean absolute offset, a proxy for the deformation magnitude; if it is very large, training has usually diverged
        if offset_absmean > 50:
            logger = get_root_logger()
            logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
            # can be replaced directly with the line below
            # print(f'Offset abs mean is {offset_absmean}, larger than 50.')

③ The last part of forward chooses which DCN implementation to use:

        # check the torchvision version to choose which DCN implementation to use
        # my torchvision version is 0.11.3
        if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
            # use the official PyTorch API
            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                                 self.dilation, mask)
        else:
            # use the API implemented by the BasicSR library
            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                         self.dilation, self.groups, self.deformable_groups)
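For the record, `LooseVersion` comes from `distutils.version`, which is deprecated (distutils was removed from the standard library in Python 3.12); newer code would typically compare versions with the `packaging` package instead. A dependency-free sketch of how the branch above resolves for the torchvision version used in this post (naive parser, my own illustration; it would not handle suffixes like `0.11.3+cu113`):

```python
def version_tuple(v):
    """Naively parse 'X.Y.Z' into a tuple of ints for comparison."""
    return tuple(int(p) for p in v.split('.'))

# torchvision 0.11.3 >= 0.9.0, so the torchvision.ops.deform_conv2d branch runs
print(version_tuple('0.11.3') >= version_tuple('0.9.0'))  # True
```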

Since my torchvision version is 0.11.3, the PyTorch implementation of the DCN API is selected at runtime. BasicSR is an open-source, PyTorch-based toolbox for image and video restoration: super-resolution, denoising, deblurring, JPEG-artifact removal, and so on. In fact, the code in this post comes from BasicSR. (If I understand correctly, the author of EDVR is also the developer of this toolbox.)
Going any deeper would mean walking through the PyTorch and BasicSR implementations of DCN themselves. I'll stop here for now and leave that for when time allows (a gap to fill later, hhh).


Alright. At this point we have covered the DCNv2Pack class quite thoroughly (minus the gap just mentioned, of course. I have long wanted to go through the DCN code carefully, but there is too much going on; it will have to wait until EDVR and the other DCN-based models are all written up).
And with DCNv2Pack clear, the PCDAlignment module is clear as well.


2.2 The fusion module

First, the full code of the TSAFusion class:
(To be filled in.)


3. References

1. https://blog.csdn.net/MR_kdcon/article/details/124170948?spm=1001.2014.3001.5502
2. http://events.jianshu.io/p/05abb917ae57
