实验笔记之——基于CARN的Octave Convolution实验记录

本博文跟上一篇博文《实验笔记之——基于SRResNet的Octave Convolution实验记录》类似,只是改为CARN的结构。用octave convolution代替传统的convolution

给出CARN的代码https://github.com/nmhkahn/CARN-pytorch

论文https://arxiv.org/pdf/1803.08664.pdf

 

理论

作者主要围绕轻量级来论述本文。CARN(Cascading Residual Network,级联残差网络)

它具有以下三个特征:

  1. 全局和局部级联连接
  2. 中间特征是级联的,且被组合在1×1大小的卷积块中
  3. 使用多级表示和快捷连接,让信息传递更高效

为了提升CARN的效率,作者提出了一种残差-E模块,即将普通Residual Block中的conv换成group conv。作者在这里选择使用group conv而不是depthwise convolution,并解释说group conv比depthwise convolution可以更好地调整模型的有效性。关于group conv和depthwise conv可以参考博文《各种卷积层的理解(深度可分离卷积、分组卷积、扩张卷积、反卷积)》。为了进一步减少模型的参数,作者使用了recursive network,让网络的参数进行共享。

 

在局部和全局使用级联的好处如下:

  1. 结合多层的特征,可以学习到多水平的表示。
  2. 级联可以让低层的特征很快的传播到高层。

 

参考材料

https://blog.csdn.net/alxe_made/article/details/85839802

https://blog.csdn.net/Chaolei3/article/details/79374563

https://zhuanlan.zhihu.com/p/28749411

https://zybuluo.com/hanbingtao/note/626300

实验

CARN

residual block与E residual block

class BasicBlock(nn.Module):
    """One convolution followed by an in-place ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        ksize: kernel size passed to ``nn.Conv2d``.
        stride: stride passed to ``nn.Conv2d``.
        pad: padding passed to ``nn.Conv2d``.
    """

    def __init__(self,
                 in_channels, out_channels,
                 ksize=3, stride=1, pad=1):
        super(BasicBlock, self).__init__()

        conv = nn.Conv2d(in_channels, out_channels, ksize, stride, pad)
        activation = nn.ReLU(inplace=True)
        # Keep the attribute name and layer order: they define state_dict keys.
        self.body = nn.Sequential(conv, activation)

        # NOTE(review): the bound method ``self.modules`` itself (not the
        # result of calling it) is handed to init_weights, which is defined
        # outside this block -- confirm that is what init_weights expects.
        init_weights(self.modules)

    def forward(self, x):
        """Return ``relu(conv(x))``."""
        return self.body(x)


class ResidualBlock(nn.Module):
    """Residual block: conv3x3 -> ReLU -> conv3x3 plus an identity shortcut.

    The output is ``relu(body(x) + x)``.  The input is added without any
    projection, so it has to be broadcast-compatible with the body output
    (in practice ``in_channels == out_channels``).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self,
                 in_channels, out_channels):
        super(ResidualBlock, self).__init__()

        # Layer order is preserved so state_dict keys (body.0, body.2) match.
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        ]
        self.body = nn.Sequential(*layers)

        # NOTE(review): passes the bound method, not ``self.modules()``;
        # init_weights lives outside this block -- confirm its expectation.
        init_weights(self.modules)

    def forward(self, x):
        """Apply the two convolutions, add the shortcut, then ReLU."""
        residual = self.body(x)
        return F.relu(residual + x)


class EResidualBlock(nn.Module):
    """Efficient residual block built from grouped convolutions.

    Body: grouped conv3x3 -> ReLU -> grouped conv3x3 -> ReLU -> conv1x1.
    The output is ``relu(body(x) + x)``; the shortcut has no projection, so
    the input must be broadcast-compatible with the body output.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by every convolution in the body.
        group: ``groups`` value used by the two 3x3 convolutions.
    """

    def __init__(self,
                 in_channels, out_channels,
                 group=1):
        super(EResidualBlock, self).__init__()

        # Same layer order as before so state_dict keys stay body.0/2/4.
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            # Pointwise convolution (ungrouped) closing the body.
            nn.Conv2d(out_channels, out_channels, 1, 1, 0),
        ]
        self.body = nn.Sequential(*layers)

        # NOTE(review): passes the bound method, not ``self.modules()``;
        # init_weights lives outside this block -- confirm its expectation.
        init_weights(self.modules)

    def forward(self, x):
        """Run the grouped-conv body, add the shortcut, then ReLU."""
        residual = self.body(x)
        return F.relu(residual + x)

先复现CARN,给出网络结构如下:

##############################################################################################
#CARN
class CARN(nn.Module):
    """CARN-style super-resolution network built from cascading blocks.

    Pipeline: ``fea_conv`` -> ``nb`` x (cascade block -> concatenate with all
    earlier features -> 1x1 fusion conv) -> ``P_conv`` -> pixel shuffle ->
    sigmoid.  The blog's experiment uses nb=3 cascade blocks and nf=24
    channels.

    Args:
        in_nc: channels of the low-resolution input image.
        out_nc: channels of the super-resolved output image.
        nf: feature channels used throughout the trunk.
        nc: number of residual blocks inside each cascade block.
        nb: number of cascade blocks.
        alpha: accepted but unused in this class.
        upscale: factor handed to ``nn.PixelShuffle``.
        norm_type, act_type, mode: forwarded to the ``B`` block builders.
        res_scale: forwarded to ``B.CascadeBlock``.
        upsample_mode: accepted but unused; upsampling is always pixel shuffle.
    """

    def __init__(self, in_nc, out_nc, nf=24, nc=4, nb=3, alpha=0.75, upscale=4,
                 norm_type='batch', act_type='relu', mode='NAC', res_scale=1,
                 upsample_mode='upconv'):
        super(CARN, self).__init__()
        self.nb = nb

        self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=norm_type,
                                     act_type=None, mode=mode)

        # The first CascadeBlock argument is its residual-block count, so pass
        # ``nc``.  Previously ``nf`` was passed, which silently stacked ``nf``
        # residual blocks per cascade and ignored ``nc`` altogether.
        # NOTE: this changes the parameter layout, so checkpoints trained with
        # the old wiring will not load into this definition.
        self.CascadeBlocks = nn.ModuleList([
            B.CascadeBlock(nc, nf, kernel_size=3, norm_type=norm_type,
                           act_type=act_type, mode=mode, res_scale=res_scale)
            for _ in range(nb)])
        # Fusion conv i sees the stem features plus (i + 1) cascade outputs.
        self.CatBlocks = nn.ModuleList([
            B.conv_block((i + 2) * nf, nf, kernel_size=1,
                         norm_type=norm_type, act_type=act_type, mode=mode)
            for i in range(nb)])
        # PixelShuffle folds upscale**2 channel groups into space, so the conv
        # must emit out_nc * upscale**2 channels (was in_nc, ignoring out_nc).
        self.P_conv = B.conv_block(nf, out_nc * (upscale ** 2), kernel_size=3,
                                   norm_type=norm_type, act_type=None, mode=mode)

        self.upsampler = nn.PixelShuffle(upscale)

    def forward(self, x):
        """Map a low-resolution batch to a sigmoid-squashed upscaled batch."""
        x = self.fea_conv(x)
        pre_fea = x
        for cascade, fuse in zip(self.CascadeBlocks, self.CatBlocks):
            res = cascade(x)
            # Global cascading: keep every earlier feature map around.
            pre_fea = torch.cat((pre_fea, res), dim=1)
            x = fuse(pre_fea)
        x = self.P_conv(x)
        # torch.sigmoid is the canonical call; F.sigmoid emitted deprecation
        # warnings in several PyTorch releases.
        return torch.sigmoid(self.upsampler(x))


##############################################################################################
class CascadeBlock(nn.Module):
    """Local cascading block: ``nc`` residual blocks with dense 1x1 fusion.

    After each residual block its output is concatenated (channel dim) with
    the block input and all previous residual outputs, then reduced back to
    ``gc`` channels by a kernel-size-1 ``conv_block``.

    Args:
        nc: number of residual blocks (and fusion convs) in the cascade.
        gc: channel width of every feature map inside the block.
        kernel_size, stride, dilation, groups, bias, pad_type, norm_type,
        act_type, mode, res_scale: forwarded positionally to ``ResNetBlock``;
        ``bias``, ``pad_type``, ``norm_type``, ``act_type`` and ``mode`` are
        also forwarded to the fusion ``conv_block``.
    """

    def __init__(self, nc, gc, kernel_size=3, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu',
                 mode='CNA', res_scale=1):
        super(CascadeBlock, self).__init__()
        self.nc = nc

        # Residual blocks are created first, then the fusion convs, matching
        # the original construction order.
        self.ResBlocks = nn.ModuleList()
        for _ in range(nc):
            self.ResBlocks.append(
                ResNetBlock(gc, gc, gc, kernel_size, stride, dilation, groups, bias,
                            pad_type, norm_type, act_type, mode, res_scale))

        self.CatBlocks = nn.ModuleList()
        for n_inputs in range(2, nc + 2):
            # Fusion conv sees the block input plus (n_inputs - 1) residuals.
            self.CatBlocks.append(
                conv_block(n_inputs * gc, gc, kernel_size=1, bias=bias, pad_type=pad_type,
                           norm_type=norm_type, act_type=act_type, mode=mode))

    def forward(self, x):
        """Run the cascade and return the last fused feature map."""
        cascade = x
        for idx in range(self.nc):
            branch = self.ResBlocks[idx](x)
            cascade = torch.cat((cascade, branch), dim=1)
            x = self.CatBlocks[idx](cascade)
        return x

训练的setting

{
  "name": "carn_DIV2K_baseline" //  please remove "debug_" during training
  , "tb_logger_dir": "octave"
  , "use_tb_logger": true
  , "model":"sr"
  , "scale": 4
  , "crop_scale": 0
  , "gpu_ids": [0,1,2,3]
//  , "init_type": "kaiming"
//
//  , "finetune_type": "sft"
//  , "init_norm_type": "zero"

  , "datasets": {
    "train": {
      "name": "DIV2K800"
      , "mode": "LRHR"
      , "dataroot_HR": "/media/sdc/wpguan/BasicSR_datasets/DIV2K800_sub"
      , "dataroot_LR": "/media/sdc/wpguan/BasicSR_datasets/DIV2K800_sub_bicLRx4"
      , "subset_file": null
      , "use_shuffle": true
      , "n_workers": 8
      , "batch_size": 16 // 16
      , "HR_size": 128 // 128 | 192 | 96
      , "noise_gt": true
      , "use_flip": true
      , "use_rot": true
    }

  , "val": {
      "name": "set5"
      , "mode": "LRHR"
      , "dataroot_HR": "/media/sdc/wpguan/BasicSR_datasets/val_set5/Set5/"
      , "dataroot_LR": "/media/sdc/wpguan/BasicSR_datasets/val_set5/Set5_bicLRx4/"
      , "noise_gt": false
    }

  }

  , "path": {
    "root": "/home/wpguan/SR_master/octave"
    , "pretrain_model_G": null
  }


//
  , "network_G": {
    "which_model_G": "carn" // RRDB_net | sr_resnet | modulate_denoise_resnet |noise_subnet  octave_resnet, octave_carn
//    , "norm_type": "adaptive_conv_res"
    , "norm_type": null
    , "mode": "CNA"
    , "nf": 24//24//64
    , "nb": 3//3//16
    , "in_nc": 3
    , "out_nc": 3
//    , "gc": 32
    , "group": 1
//    , "gate_conv_bias": true
//    , "ada_ksize": 1
//    , "num_classes": 2
  }


//    , "network_G": {
//    "which_model_G": "srcnn" // RRDB_net | sr_resnet
//    , "norm_type": null
//    , "norm_type": "adaptive_conv_res"
//    , "mode": "CNA"
//    , "nf": 64
//    , "in_nc": 1
//    , "out_nc": 1
//    , "ada_ksize": 5
//  }


  , "train": {
//    "lr_G": 1e-3
    "lr_G": 8e-4
    , "lr_scheme": "MultiStepLR"
    , "lr_steps": [210000, 350000, 500000]
//    , "lr_steps": [500000]
    , "lr_gamma": 0.5


    , "pixel_criterion": "l2"

    , "pixel_criterion_reg": "tv"

    , "pixel_weight": 1.0
    , "val_freq": 1e3

    , "manual_seed": 0
    , "niter": 6e5
  }

  , "logger": {
    "print_freq": 200
    , "save_checkpoint_freq": 1e3
  }
}

结果:

 

 

 

octave_CARN

再给出octave_carn网络的结构:

##############################################################################################
#Octave CARN
class Octave_CARN(nn.Module):
    """CARN variant whose trunk uses octave convolutions.

    Pipeline: ``fea_conv`` -> ``oct_first`` (split into a (high, low) pair) ->
    ``nb`` x (octave cascade block -> per-branch concatenation with all
    earlier features -> 1x1 octave fusion conv) -> ``oct_last`` (merge back to
    one tensor) -> ``HR_conv1`` -> pixel shuffle -> sigmoid.  The blog's
    experiment uses nb=3 cascade blocks and nf=24 channels.

    Args:
        in_nc: channels of the low-resolution input image.
        out_nc: channels of the super-resolved output image.
        nf: feature channels used throughout the trunk.
        nc: number of residual blocks inside each octave cascade block.
        nb: number of octave cascade blocks.
        alpha: forwarded to every octave layer (low-frequency channel ratio).
        upscale: factor handed to ``nn.PixelShuffle``.
        norm_type, act_type, mode: forwarded to ``fea_conv``, the cascade
            blocks and the fusion convs (``oct_first``/``oct_last`` use fixed
            settings instead).
        res_scale: forwarded to ``B.OctaveCascadeBlock``.
        upsample_mode: accepted but unused; upsampling is always pixel shuffle.

    NOTE(review): the octave layers in this file take ``int(alpha * channels)``
    as the low-frequency channel count, so the concatenated features fed to
    ``CatBlocks`` only match when ``int(alpha * k * nf) == k * int(alpha * nf)``
    (true for the defaults alpha=0.75, nf=24) -- confirm before changing
    ``alpha`` or ``nf``.
    """

    def __init__(self, in_nc, out_nc, nf=24, nc=4, nb=3, alpha=0.75, upscale=4,
                 norm_type='batch', act_type='relu', mode='NAC', res_scale=1,
                 upsample_mode='upconv'):
        super(Octave_CARN, self).__init__()
        self.nb = nb

        self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=norm_type,
                                     act_type=None, mode=mode)
        self.oct_first = B.FirstOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1,
                                           dilation=1, groups=1, bias=True, pad_type='zero',
                                           norm_type=None, act_type='prelu', mode='CNA')
        self.CascadeBlocks = nn.ModuleList([
            B.OctaveCascadeBlock(nc, nf, kernel_size=3, alpha=alpha,
                                 norm_type=norm_type, act_type=act_type, mode=mode,
                                 res_scale=res_scale)
            for _ in range(nb)])
        # Fusion conv i sees the stem features plus (i + 1) cascade outputs.
        self.CatBlocks = nn.ModuleList([
            B.OctaveConv((i + 2) * nf, nf, kernel_size=1, alpha=alpha,
                         norm_type=norm_type, act_type=act_type, mode=mode)
            for i in range(nb)])
        self.oct_last = B.LastOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1,
                                         dilation=1, groups=1, bias=True, pad_type='zero',
                                         norm_type=None, act_type='prelu', mode='CNA')

        self.upsampler = nn.PixelShuffle(upscale)
        # PixelShuffle folds upscale**2 channel groups into space, so the conv
        # must emit out_nc * upscale**2 channels (was in_nc, ignoring out_nc).
        self.HR_conv1 = B.conv_block(nf, out_nc * (upscale ** 2), kernel_size=3,
                                     norm_type=None, act_type=None)

    def forward(self, x):
        """Map a low-resolution batch to a sigmoid-squashed upscaled batch."""
        x = self.fea_conv(x)
        x = self.oct_first(x)  # from here on x is a (high, low) pair
        pre_fea = x
        for cascade, fuse in zip(self.CascadeBlocks, self.CatBlocks):
            res = cascade(x)
            # Global cascading, done separately for each frequency branch.
            pre_fea = (torch.cat((pre_fea[0], res[0]), dim=1),
                       torch.cat((pre_fea[1], res[1]), dim=1))
            x = fuse(pre_fea)
        x = self.oct_last(x)  # back to a single tensor
        x = self.HR_conv1(x)
        # torch.sigmoid is the canonical call; F.sigmoid emitted deprecation
        # warnings in several PyTorch releases.
        return torch.sigmoid(self.upsampler(x))


##############################################################################################
####################
class OctaveConv(nn.Module):
    """Octave convolution mapping a (high, low) feature pair to a new pair.

    Four convolutions exchange information between the two branches:
    ``h2h`` and ``l2l`` stay within a branch, ``h2l`` runs on the 2x
    average-pooled high branch and ``l2h`` is 2x nearest-upsampled before
    being added to the high branch.  The low branch is therefore expected at
    half the spatial size of the high branch.

    Args:
        in_nc: total input channels (high + low).
        out_nc: total output channels (high + low).
        kernel_size: kernel size of all four convolutions.
        alpha: share of channels given to the low branch; the low count is
            ``int(alpha * channels)`` and the high branch gets the remainder.
        stride: if 2, both inputs are average-pooled by 2 first (the
            convolutions themselves always use stride 1).
        dilation, groups, bias: forwarded to ``nn.Conv2d``.
        pad_type: ``'zero'`` uses ``get_valid_padding``; anything else pads 0.
        norm_type: if truthy, one norm layer per branch is built via ``norm``.
        act_type: if truthy, one activation is built via ``act`` and shared by
            both branches.
        mode: only validated; it does not alter the layer order here.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.7, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride

        in_h, in_l = self._split_channels(in_nc, alpha)
        out_h, out_l = self._split_channels(out_nc, alpha)

        self.l2l = nn.Conv2d(in_l, out_l,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(in_l, out_h,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_h, out_l,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_h, out_h,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        # Size the norms from the same split as the convs.  The previous
        # int(out_nc * (1 - alpha)) could disagree with out_nc - int(alpha *
        # out_nc), e.g. 7 vs 8 channels for out_nc=24, alpha=0.7.
        self.n_h = norm(norm_type, out_h) if norm_type else None
        self.n_l = norm(norm_type, out_l) if norm_type else None

    @staticmethod
    def _split_channels(channels, alpha):
        """Return ``(high, low)`` channel counts; low is ``int(alpha * channels)``."""
        low = int(alpha * channels)
        return channels - low, low

    def forward(self, x):
        """Take ``(X_h, X_l)`` and return the transformed ``(X_h, X_l)``."""
        X_h, X_l = x

        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)

        X_h2h = self.h2h(X_h)
        X_l2h = self.upsample(self.l2h(X_l))
        X_l2l = self.l2l(X_l)
        X_h2l = self.h2l(self.h2g_pool(X_h))

        X_h = X_l2h + X_h2h
        X_l = X_h2l + X_l2l

        # Explicit None checks: module truthiness can fall back to __len__
        # for container modules, which is not what is meant here.
        if self.n_h is not None and self.n_l is not None:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)

        if self.a is not None:
            X_h = self.a(X_h)
            X_l = self.a(X_l)

        return X_h, X_l


class FirstOctaveConv(nn.Module):
    """Entry octave layer: split one tensor into a (high, low) feature pair.

    ``h2h`` runs on the input at full resolution; ``h2l`` runs on the 2x
    average-pooled input, so the low branch comes out at half the spatial
    size.

    Args:
        in_nc: channels of the single input tensor.
        out_nc: total output channels (high + low).
        kernel_size: kernel size of both convolutions.
        alpha: share of output channels given to the low branch; the low
            count is ``int(alpha * out_nc)`` and the high branch gets the rest.
        stride: if 2, the input is average-pooled by 2 first (the
            convolutions themselves always use stride 1).
        dilation, groups, bias: forwarded to ``nn.Conv2d``.
        pad_type: ``'zero'`` uses ``get_valid_padding``; anything else pads 0.
        norm_type: if truthy, one norm layer per branch is built via ``norm``.
        act_type: if truthy, one activation is built via ``act`` and shared by
            both branches.
        mode: only validated; it does not alter the layer order here.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.7, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.stride = stride

        out_h, out_l = self._split_channels(out_nc, alpha)

        self.h2l = nn.Conv2d(in_nc, out_l,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc, out_h,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        # Size the norms from the same split as the convs.  The previous
        # int(out_nc * (1 - alpha)) could disagree with out_nc - int(alpha *
        # out_nc), e.g. 7 vs 8 channels for out_nc=24, alpha=0.7.
        self.n_h = norm(norm_type, out_h) if norm_type else None
        self.n_l = norm(norm_type, out_l) if norm_type else None

    @staticmethod
    def _split_channels(channels, alpha):
        """Return ``(high, low)`` channel counts; low is ``int(alpha * channels)``."""
        low = int(alpha * channels)
        return channels - low, low

    def forward(self, x):
        """Take one tensor and return the ``(X_h, X_l)`` pair."""
        if self.stride == 2:
            x = self.h2g_pool(x)

        X_h = self.h2h(x)
        X_l = self.h2l(self.h2g_pool(x))

        # Explicit None checks: module truthiness can fall back to __len__
        # for container modules, which is not what is meant here.
        if self.n_h is not None and self.n_l is not None:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)

        if self.a is not None:
            X_h = self.a(X_h)
            X_l = self.a(X_l)

        return X_h, X_l


class LastOctaveConv(nn.Module):
    """Exit octave layer: merge a (high, low) feature pair into one tensor.

    ``h2h`` runs on the high branch; ``l2h`` runs on the low branch and its
    output is 2x nearest-upsampled before the two results are summed.  Both
    convolutions emit ``out_nc`` channels.

    Args:
        in_nc: total input channels (high + low).
        out_nc: channels of the merged output tensor.
        kernel_size: kernel size of both convolutions.
        alpha: share of input channels on the low branch; the low count is
            ``int(alpha * in_nc)`` and the high branch holds the remainder.
        stride: if 2, both inputs are average-pooled by 2 first (the
            convolutions themselves always use stride 1).
        dilation, groups, bias: forwarded to ``nn.Conv2d``.
        pad_type: ``'zero'`` uses ``get_valid_padding``; anything else pads 0.
        norm_type: if truthy, a norm layer over ``out_nc`` channels is built.
        act_type: if truthy, an activation is built via ``act``.
        mode: only validated; it does not alter the layer order here.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.7, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        if pad_type == 'zero':
            padding = get_valid_padding(kernel_size, dilation)
        else:
            padding = 0

        low_in = int(alpha * in_nc)
        high_in = in_nc - low_in

        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride

        # Creation order (l2h, then h2h) is kept as in the original.
        self.l2h = nn.Conv2d(low_in, out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(high_in, out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)

        if act_type:
            activation = act(act_type)
        else:
            activation = None
        self.a = activation

        if norm_type:
            normalizer = norm(norm_type, out_nc)
        else:
            normalizer = None
        self.n_h = normalizer

    def forward(self, x):
        """Take ``(X_h, X_l)`` and return the single merged tensor."""
        high, low = x

        if self.stride == 2:
            high, low = self.h2g_pool(high), self.h2g_pool(low)

        from_high = self.h2h(high)
        from_low = self.upsample(self.l2h(low))
        merged = from_high + from_low

        if self.n_h:
            merged = self.n_h(merged)

        if self.a:
            merged = self.a(merged)

        return merged

class OctaveCascadeBlock(nn.Module):
    """Octave version of the local cascading block.

    Works on (high, low) feature pairs: after each ``OctaveResBlock`` its
    output is concatenated, branch by branch, with the block input and all
    previous residual outputs, then reduced back to ``gc`` total channels by
    a kernel-size-1 ``OctaveConv``.

    Args:
        nc: number of residual blocks (and fusion convs) in the cascade.
        gc: total channel width (high + low) of every feature pair.
        kernel_size, alpha, stride, dilation, groups, bias, pad_type,
        norm_type, act_type, mode, res_scale: forwarded positionally to
        ``OctaveResBlock``; ``alpha``, ``bias``, ``pad_type``, ``norm_type``,
        ``act_type`` and ``mode`` are also forwarded to the fusion
        ``OctaveConv``.

    NOTE(review): ``OctaveConv`` derives its low-branch input width as
    ``int(alpha * (k * gc))`` while the concatenation here yields ``k`` times
    the per-feature low width; these agree only for some ``alpha``/``gc``
    combinations (e.g. alpha=0.75, gc=24) -- confirm before using other
    values, including the default alpha=0.7.
    """

    def __init__(self, nc, gc, kernel_size=3, alpha=0.7, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu',
                 mode='CNA', res_scale=1):
        super(OctaveCascadeBlock, self).__init__()
        self.nc = nc

        # Residual blocks are created first, then the fusion convs, matching
        # the original construction order.
        self.ResBlocks = nn.ModuleList()
        for _ in range(nc):
            self.ResBlocks.append(
                OctaveResBlock(gc, gc, gc, kernel_size, alpha, stride, dilation,
                               groups, bias, pad_type, norm_type, act_type, mode, res_scale))

        self.CatBlocks = nn.ModuleList()
        for n_inputs in range(2, nc + 2):
            # Fusion conv sees the block input plus (n_inputs - 1) residuals.
            self.CatBlocks.append(
                OctaveConv(n_inputs * gc, gc, kernel_size=1, alpha=alpha, bias=bias,
                           pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode))

    def forward(self, x):
        """Run the cascade on a (high, low) pair and return the fused pair."""
        cascade = x
        for idx in range(self.nc):
            branch = self.ResBlocks[idx](x)
            high = torch.cat((cascade[0], branch[0]), dim=1)
            low = torch.cat((cascade[1], branch[1]), dim=1)
            cascade = (high, low)
            x = self.CatBlocks[idx](cascade)
        return x

给出setting:

{
  "name": "octave_carn_DIV2K_alpha0.75" //  please remove "debug_" during training
  , "tb_logger_dir": "octave"
  , "use_tb_logger": true
  , "model":"sr"
  , "scale": 4
  , "crop_scale": 0
  , "gpu_ids": [0,3]
//  , "init_type": "kaiming"
//
//  , "finetune_type": "sft"
//  , "init_norm_type": "zero"

  , "datasets": {
    "train": {
      "name": "DIV2K800"
      , "mode": "LRHR"
      , "dataroot_HR": "/media/sdc/wpguan/BasicSR_datasets/DIV2K800_sub"
      , "dataroot_LR": "/media/sdc/wpguan/BasicSR_datasets/DIV2K800_sub_bicLRx4"
      , "subset_file": null
      , "use_shuffle": true
      , "n_workers": 8
      , "batch_size": 16 // 16
      , "HR_size": 128 // 128 | 192 | 96
      , "noise_gt": true
      , "use_flip": true
      , "use_rot": true
    }

  , "val": {
      "name": "set5"
      , "mode": "LRHR"
      , "dataroot_HR": "/media/sdc/wpguan/BasicSR_datasets/val_set5/MSet5"
      , "dataroot_LR": "/media/sdc/wpguan/BasicSR_datasets/val_set5/MSet5_bicLRx4"
      , "noise_gt": false
    }

  }

  , "path": {
    "root": "/home/wpguan/SR_master/octave"
    , "pretrain_model_G": null
  }


//
  , "network_G": {
    "which_model_G": "octave_carn" // RRDB_net | sr_resnet | modulate_denoise_resnet |noise_subnet  octave_resnet, octave_carn
//    , "norm_type": "adaptive_conv_res"
    , "norm_type": null
    , "mode": "CNA"
    , "nf": 24//64
    , "nb": 3//16
    , "in_nc": 3
    , "out_nc": 3
//    , "gc": 32
    , "group": 1
//    , "gate_conv_bias": true
//    , "ada_ksize": 1
//    , "num_classes": 2
  }


//    , "network_G": {
//    "which_model_G": "srcnn" // RRDB_net | sr_resnet
//    , "norm_type": null
//    , "norm_type": "adaptive_conv_res"
//    , "mode": "CNA"
//    , "nf": 64
//    , "in_nc": 1
//    , "out_nc": 1
//    , "ada_ksize": 5
//  }


  , "train": {
//    "lr_G": 1e-3
    "lr_G": 6e-4
    , "lr_scheme": "MultiStepLR"
    , "lr_steps": [200000, 400000, 600000, 800000]
//    , "lr_steps": [500000]
    , "lr_gamma": 0.5


    , "pixel_criterion": "l2"

    , "pixel_criterion_reg": "tv"

    , "pixel_weight": 1.0
    , "val_freq": 1e3

    , "manual_seed": 0
    , "niter": 1e6
  }

  , "logger": {
    "print_freq": 200
    , "save_checkpoint_freq": 1e3
  }
}

结果如下图所示

 

实验python train.py -opt options/train/train_sr.json

先激活虚拟环境source activate pytorch

tensorboard --logdir tb_logger/ --port 6008

浏览器打开http://172.20.36.203:6008/#scalars
(拿到新的服务器~重新跑实验~)

 

 

 

 

 

 

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值