实验笔记之——基于SRResNet的Octave Convolution实验记录

先给出论文的链接(https://arxiv.org/pdf/1904.05049.pdf

github连接(https://github.com/terrychenism/OctaveConv/

本博文为文章《Drop an Octave Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution》的阅读笔记及基于SRResNet的Octave Convolution的实验与探索

Octave原理:

octconv就如同CNN的压缩器,代替传统的卷积,能在提升效果的同时,节约计算资源。比如说一个经典的图像识别算法,换掉其中的传统卷积,在ImageNet上的识别精度能获得1.2%的提升,同时,只需要82%的算力和91%的存储空间。如果对精度没有那么高的要求,和原来持平满足了的话,只需要一半的浮点运算能力就够了。(但貌似这种提升是在识别领域上的,因为可以扩大感受野。那么在超分任务上,感受野的大小是否有影响呢?)并且OctConv即插即用,无需修改网络架构,也不需要调整超参。

尺度空间理论

如果我们要处理的图像目标的大小/尺度(scale)是未知的,那么我们可以采用尺度空间理论。

其核心思想是将图像用多种尺度表示,这些表示统称为尺度空间表示(scale-space representation)
我们对图像用一系列高斯滤波器加以平滑,而这些高斯滤波器的尺寸是不同的
这样,我们就得到了该图像在不同尺度下的表示。

尺度空间方法最重要的属性是尺度不变性(scale invariant),使得我们可以处理未知大小的图像目标。

最后要注意的是,在构造尺度空间时,往往还伴随着降采样

关于这部分的描述可以参考本文之前的博客( 学习笔记之——vs2015+opencv2.4.13实现SIFT、SURF、ORB

 

论文

核心原理就是利用空间尺度化理论将图像高频低频部分分开,下采样低频部分,可以大大降低参数量,并且可以完美的嵌入到神经网络中。降低了低频信息的冗余。

CNNs生成的特征图在空间维度上也存在大量冗余,每个位置独立存储自己的特征描述符,忽略了可以一起存储和处理的相邻位置之间的公共信息

通过octave convolution 减少空间冗余度。不仅仅自然图像,在卷积层输出的特征图中也存在高低频分量。而低频分量的存在是冗余的,在编码过程可以节省。

在图像中,一般包含了低频信息(全局结构)和高频信息(细节)。CONV输出的feature map也相当于不同频率信息的混合。

在本文中通过频率来分解混合的feature map,设计Octave Convolution,以较低的空间分辨率存储和处理空间变化“较慢”的特征图,从而降低内存和计算成本。同时CctConv是一种单一、通用、即插即用的卷积单元,可以直接代替(普通)卷积,而无需对网络结构进行任何调整。它也是正交和互补的方法,建议更好的拓扑结构或减少信道冗余,如分组或深度卷积。通过简单将OctConv代替传统的卷积层,可以提升分类任务的精确度,同时可以减少存储空间与计算量

在自然图像中,信息以不同的频率传递,其中较高的频率通常用精细的细节编码,较低的频率通常用全局结构编码。同样,卷积层的输出特征图也可以看作是不同频率下信息的混合。

如图1(a)所示,自然图像可以分解为描述平稳变化结构的低空间频率分量和描述快速变化精细细节的高空间频率分量[1,12]。同样,我们认为卷积层的输出特征映射也可以分解为不同空间频率的特征,并提出了一种新的多频特征表示方法,将高频和低频特征映射存储到不同的组中,如图1(b)所示。通过相邻位置间的信息的共享,可以完全降低低频组的空间分辨率,减少空间冗余。如图1(c)所示。将张量特征图包含两个频率和一个octave部分,直接从低频图中提取信息,无需将其解码回高频。

a natural image can be decomposed into a low spatial frequency component that describes the smoothly changing structure and a high spatial frequency component that describes the rapidly changing fine details

the output featuremaps of a convolution layer can also be decomposed into features of different spatial frequencies

propose a novel multi-frequency feature representation which stores high- and low-frequency feature maps into different groups

作者通过OctConv来takes in feature maps containing tensors of two frequencies one octave apart, and extracts information directly from the low-frequency maps without the need of decoding it back to the high-frequency(如图d所示)

OctConv代替传统的卷积层,可以减少计算量与存储量,OctConv采用低频卷积处理对应的低频信息,并有效地扩大了原始像素空间的感受野,进而提升识别率(对于denoise任务,扩大感受野可能有利于去噪?)

 

对于普通的卷积层。所有的输入输出特征图都有相同的空间分辨率。作者认为,有一个feature map的子集,捕捉空间低频变化,并包含空间冗余信息。而为了减少空间冗余,作者引入了octave特征表示法。将feature map张量分解为低频与高频。空间尺度理论提供了一种方法来构建空间分辨率的尺度空间,以及定义一个octave作为空间尺寸的一个分割(将octave定义为空间维度除以2的幂)。通过一个octave来减少低频特征图的空间分辨率

 

octave feature representation

对于普通卷积,所有的输入和输出特征图具有相同的空间分辨率。然而,空间尺度模型认为自然图像可以分解为捕捉全局布局和粗结构的低频信号和捕捉精细细节的高频信号。作者认为有一个特征映射子集,它捕获空间低频变化,并包含空间冗余信息。

octave特征表示,它显式地将特征映射张量分解为对应于低频和高频的组。尺度空间理论为我们提供了一种创建空间分辨率尺度空间的原则方法,并将octave定义为空间维度除以2的幂。我们用这种方式定义了低频和高频空间,即将低频地物图的空间分辨率降低一个octave。

(注意此处的alpha,本博文的实验,主要是围绕着该参数进行)

Octave Convolution

所提出的octave feature representation特征表示方法减少了空间冗余,比原表示方法更加紧凑。然而,由于输入特征的空间分辨率不同,普通卷积不能直接对这种表示进行操作。绕过这个问题的一种简单方法是将低频部分X^{L}上采样到原始的空间分辨率,将它与X^{H}连接起来,然后进行卷积,这将导致额外的计算和内存开销,并减少压缩带来的所有节省。为了充分利用我们紧凑的多频特征表示,我们引入了Octave Convolution,它可以直接作用于因式张量X=\left\{X^{H}, X^{L}\right\},而不需要任何额外的计算或内存开销。

Octave Convolution 我们的设计目标是有效地处理相应频率张量中的低频和高频分量,同时使我们的octave特征表示的高频分量和低频分量之间能够有效地通信。设X, Y为因式分解的输入和输出张量。那么高和低频特征图的输出Y=\left\{Y^{H}, Y^{L}\right\}将由Y^{H}=Y^{H \rightarrow H}+Y^{L \rightarrow H}Y^{L}=Y^{L \rightarrow L}+Y^{H \rightarrow L}分别在Y^{A \rightarrow B}表示卷积更新从功能映射组B组。具体来说,Y^{H \rightarrow H}, Y^{L \rightarrow L}表示intra-frequency信息更新,而Y^{H \rightarrow L}, Y^{L \rightarrow H}表示inter-frequency沟通。

为了计算这些项,我们将卷积核W分成两个分量W=\left[W^{H}, W^{L}\right],分别负责与X^{H}X^{L}进行卷积。将各分量进一步划分为频率内分量和频率间分量:W^{H}=\left[W^{H \rightarrow H}, W^{L \rightarrow H}\right]W^{L}=\left[W^{L \rightarrow L}, W^{H \rightarrow L}\right],参数张量形状如图2(b)所示。

对于高频地形图,我们在位置(p, q)处进行计算,使用正则卷积对频率内的更新进行计算,对频率间的更新进行计算,我们可以将特征张量X^{L}上采样折叠成卷积,不需要显式计算和存储上采样的特征图,如下所示:

Y_{p, q}^{H}=Y_{p, q}^{H \rightarrow H}+Y_{p, q}^{L \rightarrow H} =\sum_{i, j \in \mathcal{N}_{k}} {W_{i+\frac{k-1}{2}, j+\frac{k-1}{2}}^{H \rightarrow H}}^\top \quad X_{p+i, q+j}^H +\sum_{i, j \in \mathcal{N}_{k}} {W_{i+\frac{k-1}{2}, j+\frac{k-1}{2}}^{L \rightarrow H} }^\top X^L_{\left(\left\lfloor\frac{p}{2}\right\rfloor+ i\right),\left(\left\lfloor\frac{q}{2}\right\rfloor+ j\right)}

\lfloor \cdot \rfloor 表示floor操作。同样,对于低频特征图,我们使用正则卷积计算频率内更新。注意,由于map在一个octave以下,卷积也是低频的w.r.t.,即高频坐标空间。对于频间通信,我们可以再次将特征张量X^{H}下采样折叠成卷积,如下所示:

Y_{p, q}^{L}=Y_{p, q}^{L \rightarrow L}+Y_{p, q}^{H \rightarrow L} =\sum_{i, j \in \mathcal{N}_{k}}{ W_{i+\frac{k-1}{2}, j+\frac{k-1}{2}}^{L \rightarrow L}}^\top \quad X_{p+i, q+j}^{L} +\sum_{i, j \in \mathcal{N}_{k}} {W_{i+\frac{k-1}{2}, j+\frac{k-1}{2}}^{H \rightarrow L}}^\top X^H_{(2 * p+0.5+i),(2 * q+0.5+j)}

将一个因子2乘以位置(p, q)进行向下采样,并进一步将位置移动半步,以确保向下采样的映射与输入保持良好的对齐。然而,由于X^{H}的索引只能是一个整数,我们可以将索引四舍五入到(2 * p+i, 2 * q+j),或者通过对所有4个相邻位置求平均值来近似(2 * p+0.5+i, 2 * q+0.5+j)。第一个也被称为条纹卷积第二个是平均池。正如我们在3.3节和图3中所讨论的,条纹卷积会导致失调;因此,在本文的其余部分,我们使用平均池来近似这个值。

Octave Convolution的一个有趣和有用的性质是低频特征图的具有更大的感受野。将低频部分X^{L}与k x k卷积核进行卷积,与普通卷积相比,有效地将感受野扩大了2倍。这进一步帮助每个OctConv层从遥远的位置捕获更多的上下文信息,并可能提高识别性能。

 

 

 

要实现OctConv,有以下两个关键步骤:

1、第一步,我们要获得输入通道(或图像)的线性尺度表示,称为Octave feature representation。所谓高频分量,是指不经过高斯滤波的原始通道(或图像);所谓低频分量,是指经过t=2t=2的高斯滤波得到的通道(或图像)。由于低频分量是冗余的,因此作者将低频分量的通道长/ 宽设置为高频分量通道长/ 宽的一半

在音乐中,Octave是八音阶的意思,隔一个八音阶,频率会减半;在这里,drop an octave就是通道尺寸减半的含义。
那么高频通道和低频通道比例是多少呢?作者设置了一个超参数α∈[0,1],表示低频通道的比例
在本文中,输入通道低频比例αin和输出通道低频比例αout设为相同。

设计目标是有效地处理相应频率张量中的低频和高频分量,同时使我们的octave特征表示的高频分量和低频分量之间能够有效地通信。由于输入输出的空间分辨率不一致,普通的CONV不能直接作用。可以将XL上采样,然后再与XH节后到一起进行卷积操作。这会导致额外的计算量与存贮量。因此,作者直接通过octave convolution作用于因子分解张量而不需要任何额外的计算或内存开销。

2、第二步

设X, Y为因式分解的输入和输出张量。那么高和低频特征图的输出Y=\left\{Y^{H}, Y^{L}\right\}将由Y^{H}=Y^{H \rightarrow H}+Y^{L \rightarrow H}Y^{L}=Y^{L \rightarrow L}+Y^{H \rightarrow L}分别在Y^{A \rightarrow B}表示卷积更新从功能映射组B组。具体来说,Y^{H \rightarrow H}, Y^{L \rightarrow L}表示intra-frequency信息更新,而Y^{H \rightarrow L}, Y^{L \rightarrow H}表示inter-frequency沟通。

为了计算这些项,我们将卷积核W分成两个分量W=\left[W^{H}, W^{L}\right],分别负责与X^{H}X^{L}进行卷积。将各分量进一步划分为频率内分量和频率间分量:W^{H}=\left[W^{H \rightarrow H}, W^{L \rightarrow H}\right]W^{L}=\left[W^{L \rightarrow L}, W^{H \rightarrow L}\right],参数张量形状如图2(b)所示。

其中

具体方法很简单,就是取值的问题:


降采样后卷积相当于有步长的卷积,会不太精确;因此作者最后选择了平均池化(pooling)的方式,平均取值,采样结果会较精确一些。完整流程如下图所示。

本文提出将 特征图分为两组:低频特征(蓝色)和 高频特征(橙红),并将空间上变化较为缓慢的「低频特征图」存储在低分辨率的张量中,共享相邻位置间的特征。而本文所提出的 OctConv 则是一种可以直接作用在该特征表达下的卷积操作。它包含每个频率自身状态的更新(绿色箭头),以及频率间的信息交互(红色箭头)

我们可以发现,这种滤波+新式卷积的操作是“插片式”的,不需要破坏原来的CNN框架。值得注意的是,低频通道卷积的感受野比传统卷积更大。通过调整低频比例α,预测精度和计算代价可以得到权衡(trade-off)。

1)OctConv可以帮助CNNs提高准确性,而减少FLOPs,与其他方法不同的是,这些方法以较低的精度为代价来减少故障。

2)在测试时,OctConv比基线模型的增益随着测试图像分辨率的增加而增加,因为OctConv的接受域较大,可以更好地检测大对象,

 

参考资料

https://blog.csdn.net/weixin_37993251/article/details/89333099

http://www.cnblogs.com/RyanXing/p/10720182.html

https://www.zhihu.com/question/320462422/answer/655569703

https://blog.csdn.net/weixin_37993251/article/details/89333099

OctConv的motivation很有意思,通过分解图像高频成分和低频成分并作一定的融合得到multi-frequency的feature representation,既可以丰富特征表示,又可以降低特征冗余(高维度大分辨率的特征确实应该存在大量冗余,不过自己没验证过),当然在参数量和计算量上也比较有优势,毕竟在low frequency(low resolution)的feature上,计算量会减75% (通道不变的情况下)

将feature分成2个scale,去做4个conv(Low->Low, High->High, Low->High, High->Low ) 然后fuse到一起,得到multi-scale representation。

 

实验:(基于SRResNet进行重塑)

SRRESNET结构如下(但本博文中用到的结构稍微有点不一样)

本part中,基于上面提及的octave convolution,将其用于超分中,并测试结构。给出网络结构如下所示。

做乘4的超分。先用bicubic进行降采样

首先定义一个基于ResNet block的octave block

class OctaveResBlock(nn.Module):
    '''
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
    '''
    def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, alpha=0.5, stride=1, dilation=1, groups=1, \
            bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA', res_scale=1):
        super(OctaveResBlock, self).__init__()
        conv0 = OctaveConv(in_nc, mid_nc, kernel_size, alpha, stride, dilation, groups, bias, pad_type, \
            norm_type, act_type, mode)
        if mode == 'CNA':
            act_type = None
        if mode == 'CNAC':  # Residual path: |-CNAC-|
            act_type = None
            norm_type = None
        conv1 = OctaveConv(mid_nc, out_nc, kernel_size, alpha, stride, dilation, groups, bias, pad_type, \
            norm_type, act_type, mode)

        self.res = sequential(conv0, conv1)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.res(x)
        res = (res[0].mul(self.res_scale), res[1].mul(self.res_scale))
        x = (x[0] + res[0], x[1] + res[1])
        return x + res

里面用到octaveConv,再定义OctaveConv

class OctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride

        self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc - int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_nc - int(alpha * in_nc), int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None

    def forward(self, x):
        X_h, X_l = x

        if self.stride ==2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)

        X_h2h = self.h2h(X_h)
        X_l2h = self.upsample(self.l2h(X_l))
        X_l2l = self.l2l(X_l)
        X_h2l = self.h2l(self.h2g_pool(X_h))
        
        X_h = X_l2h + X_h2h
        X_l = X_h2l + X_l2l

        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)

        if self.a:
            X_h = self.a(X_h)
            X_l = self.a(X_l)

        return X_h, X_l

然后改写SRResNet有(实际上就是把srresnet中的conv改为octconv):

##############################################################################################
class Octave_SRResNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='batch', act_type='relu', \
            mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(Octave_SRResNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.OctaveConv(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        resnet_blocks = [B.OctaveResBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,\
            mode=mode, res_scale=res_scale) for _ in range(nb)]
        LR_conv = B.OctaveConv(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.OctaveConv(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.OctaveConv(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*resnet_blocks, LR_conv)),\
            *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x

修改对应的model文件

    elif which_model == 'octave_resnet':  # SRResNet
        netG = arch.Octave_SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \
            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \
            act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')

setting如下:

{
  "name": "octave_srresnet_DIV2K" //  please remove "debug_" during training
  , "tb_logger_dir": "octave"
  , "use_tb_logger": true
  , "model":"sr"
  , "scale": 4
  , "crop_scale": 0
  , "gpu_ids": [2]
//  , "init_type": "kaiming"
//
//  , "finetune_type": "sft"
//  , "init_norm_type": "zero"

  , "datasets": {
    "train": {
      "name": "DIV2K800"
      , "mode": "LRHR"
      , "dataroot_HR": "/media/sdc/wpguan/BasicSR_datasets/DIV2K800_sub"
      , "dataroot_LR": "/media/sdc/wpguan/BasicSR_datasets/DIV2K800_sub_bicLRx4"
      , "subset_file": null
      , "use_shuffle": true
      , "n_workers": 8
      , "batch_size": 16 // 16
      , "HR_size": 128 // 128 | 192 | 96
      , "noise_gt": true
      , "use_flip": true
      , "use_rot": true
    }

  , "val": {
      "name": "set5"
      , "mode": "LRHR"
      , "dataroot_HR": "/media/sdc/wpguan/BasicSR_datasets/val_set5/MSet5"
      , "dataroot_LR": "/media/sdc/wpguan/BasicSR_datasets/val_set5/MSet5_bicLRx4"
      , "noise_gt": false
    }

  }

  , "path": {
    "root": "/home/wpguan/SR_master/octave"
    , "pretrain_model_G": null
  }


//
  , "network_G": {
    "which_model_G": "octave_resnet" // RRDB_net | sr_resnet | modulate_denoise_resnet |noise_subnet  octave_resnet
//    , "norm_type": "adaptive_conv_res"
    , "norm_type": null
    , "mode": "CNA"
    , "nf": 64
    , "nb": 16
    , "in_nc": 3
    , "out_nc": 3
//    , "gc": 32
    , "group": 1
//    , "gate_conv_bias": true
//    , "ada_ksize": 1
//    , "num_classes": 2
  }


//    , "network_G": {
//    "which_model_G": "srcnn" // RRDB_net | sr_resnet
    , "norm_type": null
//    , "norm_type": "adaptive_conv_res"
//    , "mode": "CNA"
//    , "nf": 64
//    , "in_nc": 1
//    , "out_nc": 1
//    , "ada_ksize": 5
//  }


  , "train": {
//    "lr_G": 1e-3
    "lr_G": 6e-4
    , "lr_scheme": "MultiStepLR"
    , "lr_steps": [200000, 400000, 600000, 800000]
//    , "lr_steps": [500000]
    , "lr_gamma": 0.5


    , "pixel_criterion": "l2"

    , "pixel_criterion_reg": "tv"

    , "pixel_weight": 1.0
    , "val_freq": 2e3

    , "manual_seed": 0
    , "niter": 1e6
  }

  , "logger": {
    "print_freq": 100
    , "save_checkpoint_freq": 2e3
  }
}

代码纠正如下:(并且附上调试过程,有时候一丢丢的代码错误可能会导致卡好长一段时间。。。。。。)

####################
# Block for OctConv
####################
class OctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride

        self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc - int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_nc - int(alpha * in_nc), int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None

    def forward(self, x):
        X_h, X_l = x

        if self.stride ==2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)

        X_h2h = self.h2h(X_h)
        X_l2h = self.upsample(self.l2h(X_l))
        X_l2l = self.l2l(X_l)
        X_h2l = self.h2l(self.h2g_pool(X_h))
        
        #print(X_l2h.shape,"~~~~",X_h2h.shape)
        X_h = X_l2h + X_h2h
        X_l = X_h2l + X_l2l

        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)

        if self.a:
            X_h = self.a(X_h)
            X_l = self.a(X_l)

        return X_h, X_l


class FirstOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.stride = stride
        self.h2l = nn.Conv2d(in_nc, int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc, out_nc - int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None

    def forward(self, x):
        if self.stride ==2:
            x = self.h2g_pool(x)

        X_h = self.h2h(x)
        X_l = self.h2l(self.h2g_pool(x))

        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)

        if self.a:
            X_h = self.a(X_h)
            X_l = self.a(X_l)

        return X_h, X_l


class LastOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2,2), stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride

        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
                                kernel_size, 1, padding, dilation, groups, bias)

        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        X_h, X_l = x

        if self.stride ==2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        
        X_h2h = self.h2h(X_h)
        X_l2h = self.upsample(self.l2h(X_l))
        
        X_h = X_h2h + X_l2h

        if self.n_h:
            X_h = self.n_h(X_h)

        if self.a:
            X_h = self.a(X_h)

        return X_h

class OctaveCascadeBlock(nn.Module):
    """
    OctaveCascadeBlock, 3-3 style
    """
    def __init__(self, nc, gc, kernel_size=3, alpha=0.5, stride=1, dilation=1, groups=1, \
            bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA', res_scale=1):
        super(OctaveCascadeBlock, self).__init__()
        self.nc = nc
        self.ResBlocks = nn.ModuleList([OctaveResBlock(gc, gc, gc, kernel_size, alpha, stride, dilation, \
            groups, bias, pad_type, norm_type, act_type, mode, res_scale) for _ in range(nc)])
        self.CatBlocks = nn.ModuleList([OctaveConv((i + 2)*gc, gc, kernel_size=1, alpha=alpha, bias=bias, \
            pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode) for i in range(nc)])

    def forward(self, x):
        pre_fea = x
        for i in range(self.nc):
            res = self.ResBlocks[i](x)
            pre_fea = (torch.cat((pre_fea[0], res[0]), dim=1), \
                        torch.cat((pre_fea[1], res[1]), dim=1))
            x = self.CatBlocks[i](pre_fea)
        return x

class OctaveResBlock(nn.Module):
    '''
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
    '''
    def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, alpha=0.5, stride=1, dilation=1, groups=1, \
            bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA', res_scale=1):
        super(OctaveResBlock, self).__init__()
        conv0 = OctaveConv(in_nc, mid_nc, kernel_size, alpha, stride, dilation, groups, bias, pad_type, \
            norm_type, act_type, mode)
        if mode == 'CNA':
            act_type = None
        if mode == 'CNAC':  # Residual path: |-CNAC-|
            act_type = None
            norm_type = None
        conv1 = OctaveConv(mid_nc, out_nc, kernel_size, alpha, stride, dilation, groups, bias, pad_type, \
            norm_type, act_type, mode)

        self.res = sequential(conv0, conv1)
        self.res_scale = res_scale

    def forward(self, x):
        #if(len(x)>2):
            #print(x[0].shape,"  ",x[1].shape,"  ",x[2].shape,"  ",x[3].shape)
        #print(len(x))
        res = self.res(x)
        res = (res[0].mul(self.res_scale), res[1].mul(self.res_scale))
        x = (x[0] + res[0], x[1] + res[1])
        #print(len(x),"~~~",len(res),"~~~",len(x + res))

        return (x[0] + res[0], x[1]+res[1])
##############################################################################################
#Ocatve SRResNet
class Octave_SRResNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='batch', act_type='relu', \
            mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(Octave_SRResNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.FirstOctaveConv(in_nc, nf, kernel_size=3,  alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA')
        #first=B.FirstOctaveConv(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
        resnet_blocks = [B.OctaveResBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,\
            mode=mode, res_scale=res_scale) for _ in range(nb)]
        #last=B.LastOctaveConv(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
        LR_conv = B.LastOctaveConv(nf, nf, kernel_size=3, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA')

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, *resnet_blocks, LR_conv,\
            *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x


#####################################################################################################

实验python train.py -opt options/train/train_sr.json

先激活虚拟环境source activate pytorch

tensorboard --logdir tb_logger/ --port 6008

浏览器打开http://172.20.36.203:6008/#scalars

结果(将alpha改变分别进行实验)

当α=0 时(即没有低频成分),OctConv 就会退化为普通卷积。注意,无论比例α选择是多少,OctConv 的参数数量都是与常规卷机一致的

 

 

 

 

 

 

 

 

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值