Experiment Log: Reproduction Code for SRResNet, CARN, RRDB, and RCAN

 
Training and testing are launched with:

python train.py -opt options/train/train_sr.json
python3 test.py -opt options/test/test_sr.json

Activate the conda environment and start TensorBoard to monitor training:

source activate pytorch
tensorboard --logdir tb_logger/ --port 6008

http://172.20.36.203:6008/#scalars
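The two options files are not reproduced in this log. In a BasicSR-style codebase they are JSON files shaped roughly like the sketch below; the exact keys depend on the local code, so treat every name here as illustrative and check the files under options/:

{
  "name": "srresnet_x4",
  "model": "sr",
  "scale": 4,
  "gpu_ids": [0],
  "datasets": {
    "train": { "mode": "LRHR", "dataroot_HR": "...", "dataroot_LR": "..." }
  },
  "network_G": { "which_model_G": "sr_resnet", "in_nc": 3, "out_nc": 3, "nf": 64, "nb": 16 },
  "train": { "lr_G": 2e-4, "niter": 1000000, "pixel_criterion": "l1" }
}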

##############################################################################################
# octave_srresnet
# Imports assumed by all snippets in this log (BasicSR-style layout; adjust the
# module path to the local repo):
import math
import torch
import torch.nn as nn
import models.modules.block as B

class highorder_SRResNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='batch', act_type='relu', \
            mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(highorder_SRResNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        alpha = 0.5  # fraction of channels routed to the low-frequency branch

        # self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type='relu', mode='CNA')
        self.fea_conv = B.FirstOctaveConv(in_nc, nf, kernel_size=3, alpha=alpha, norm_type=None, act_type='relu', mode='CNA')
        
        # Baseline (plain convolutions):
        # fea_conv1 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type='relu', mode='CNA')
        # resnet_blocks = [B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)]
        # LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type='relu', mode='CNA')

        # Octave layer (Layer_HRLRadd); pick one fea_conv1/LR_conv pair and one resnet_blocks variant:
        # fea_conv1 = B.FirstOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1, bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA')
        fea_conv1 = B.OctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1, bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA')
        # plain octave residual blocks:
        # resnet_blocks = [B.OctaveResBlock(nf, nf, nf, kernel_size=3, alpha=alpha, norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)]
        # LR added to HR, keeping LR:
        # resnet_blocks = [B.LRaddHR_keepLR_OctaveResBlock(nf, nf, nf, kernel_size=3, alpha=alpha, norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)]
        # no HR<->LR connection:
        # resnet_blocks = [B.withoutconnection_OctaveResBlock(nf, nf, nf, kernel_size=3, alpha=alpha, norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)]
        # HR added to LR, keeping HR (active):
        resnet_blocks = [B.HRaddLR_keepHR_OctaveResBlock(nf, nf, nf, kernel_size=3, alpha=alpha, norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)]
        # LR_conv = B.LastOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1, bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA')
        LR_conv = B.OctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1, bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA')


        # Two-branch variant (Network_HRLRadd): split once, run separate HR and LR ResNet
        # stacks, then merge; pairs with the commented-out forward at the end of this class.
        # self.fea_conv1 = B.FirstOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1, bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA')
        # resnet_blocks_LR1 = [B.ResNetBlock(int(alpha*nf), int(alpha*nf), int(alpha*nf), norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)]
        # resnet_blocks_HR1 = [B.ResNetBlock(nf-int(alpha*nf), nf-int(alpha*nf), nf-int(alpha*nf), norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)]
        # self.LR_conv = B.LastOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1, bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA')
        # Same variant, keeping the octave split at both ends:
        # self.fea_conv1 = B.OctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1, bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA')
        # self.LR_conv = B.OctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1, bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA')



        # Block-output variants: mix HR and LR only at the output of each residual block.
        # fea_conv1 = B.FirstOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1, bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA')
        # HR and LR added (or attended) at the block output:
        # resnet_blocks = [B.out_HRLRadd_ResNetBlock(nf, nf, nf, alpha=alpha, norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)]
        # resnet_blocks = [B.out_HRLRattention_ResNetBlock(nf, nf, nf, alpha=alpha, norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)]
        # block output only, LR added to HR (keep LR):
        # resnet_blocks = [B.out_LRaddHR_keepLR_ResNetBlock(nf, nf, nf, alpha=alpha, norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)]
        # block output only, HR added to LR (keep HR):
        # resnet_blocks = [B.out_HRaddLR_keepHR_ResNetBlock(nf, nf, nf, alpha=alpha, norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)]
        # resnet_blocks = [B.out_LRattentionHR_keepLR_ResNetBlock(nf, nf, nf, alpha=alpha, norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)]
        # LR_conv = B.LastOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1, bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA')
        # "LR acts on HR" with plain convolutions at both ends:
        # fea_conv1 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type='relu', mode='CNA')
        # LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type='relu', mode='CNA')


###################################################################################################################################################

        if upsample_mode == 'upconv':
            # upsample_block = B.upconv_blcok
            upsample_block = B.octave_upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            # upsample_block = B.pixelshuffle_block
            upsample_block = B.octave_pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, alpha=alpha, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, alpha=alpha, act_type=act_type) for _ in range(n_upscale)]
        # HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        # LastOctaveConv merges the two frequency branches back into a single feature map:
        HR_conv0 = B.LastOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1, bias=True, pad_type='zero', norm_type=None, act_type=act_type)

        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
###################################################################################################################################################
        # self.model = B.ShortcutBlock(B.sequential(fea_conv1, *resnet_blocks, LR_conv))
        self.model = B.Octave_ShortcutBlock(B.sequential(fea_conv1, *resnet_blocks, LR_conv))
        # For the two-branch variant:
        # self.model_LR1 = B.sequential(*resnet_blocks_LR1)
        # self.model_HR1 = B.sequential(*resnet_blocks_HR1)

        self.subpixel_up = B.sequential(*upsampler, HR_conv0, HR_conv1)


    def forward(self, x):
        x = self.fea_conv(x)       # split into an (X_h, X_l) tuple
        x = self.model(x)          # octave residual trunk with global shortcut
        x = self.subpixel_up(x)    # upsample, merge branches, reconstruct
        return x
#######################################################################
    # Alternative forward for the two-branch variant (Network_HRLRadd):
    # def forward(self, x):
    #     x = self.fea_conv(x)
    #     res = x
    #     x = self.fea_conv1(x)
    #     x_h, x_l = x
    #     x_h = self.model_HR1(x_h)
    #     x_l = self.model_LR1(x_l)
    #     x = x_h, x_l
    #     x = self.LR_conv(x)
    #     x = (res[0] + x[0], res[1] + x[1])
    #     x = self.subpixel_up(x)
    #     return x
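The octave layers used above (B.FirstOctaveConv, B.OctaveConv, B.LastOctaveConv) are not shown in this log. For reference, they follow the "Drop an Octave" design: features are split into a high-frequency branch at full resolution and a low-frequency branch at half resolution, and every convolution exchanges information between the two. A minimal self-contained sketch (not the exact implementation in block.py; names and defaults here are assumptions):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleOctaveConv(nn.Module):
    """Minimal octave convolution: (X_h, X_l) -> (Y_h, Y_l)."""
    def __init__(self, in_nc, out_nc, kernel_size=3, alpha=0.5):
        super(SimpleOctaveConv, self).__init__()
        pad = kernel_size // 2
        l_in, l_out = int(alpha * in_nc), int(alpha * out_nc)
        h_in, h_out = in_nc - l_in, out_nc - l_out
        self.h2h = nn.Conv2d(h_in, h_out, kernel_size, 1, pad)   # high -> high
        self.h2l = nn.Conv2d(h_in, l_out, kernel_size, 1, pad)   # high -> low
        self.l2h = nn.Conv2d(l_in, h_out, kernel_size, 1, pad)   # low -> high
        self.l2l = nn.Conv2d(l_in, l_out, kernel_size, 1, pad)   # low -> low

    def forward(self, x):
        X_h, X_l = x
        # high-frequency output: same-scale path plus upsampled low->high path
        Y_h = self.h2h(X_h) + F.interpolate(self.l2h(X_l), scale_factor=2, mode='nearest')
        # low-frequency output: same-scale path plus pooled high->low path
        Y_l = self.l2l(X_l) + self.h2l(F.avg_pool2d(X_h, 2))
        return Y_h, Y_l

Judging by their names, the HRaddLR_keepHR / LRaddHR_keepLR block variants in the experiments above differ from this symmetric form in which branch receives the cross-frequency sum and which branch passes through unchanged.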

############################################################################################

RCAN reproduction

###############################################################################################################
# RCAN
class RCAN(nn.Module):
    # note: alpha is left over from the octave experiments and is unused in this plain version
    def __init__(self, in_nc, out_nc, nf=64, ng=10, nb=20, reduction=16, upscale=4, alpha=0.75, norm_type='batch', act_type='relu', \
            mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(RCAN, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type='relu', mode='CNA')
        # ng residual groups, each containing nb channel-attention residual blocks
        CA_blocks = [B.ResidualGroupBlock(nf, nb, kernel_size=3, reduction=reduction, norm_type=norm_type, \
                                act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(ng)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*CA_blocks, LR_conv)), \
            *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x
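B.ResidualGroupBlock is not reproduced in this log. Its core is RCAN's channel attention (CA): each residual block ends with a squeeze-and-excitation step that rescales channels using globally pooled statistics. A minimal sketch of the CA layer as described in the RCAN paper (class and argument names here are mine, not necessarily those in block.py):

import torch.nn as nn

class CALayer(nn.Module):
    """RCAN-style channel attention: squeeze -> excite -> rescale."""
    def __init__(self, nf, reduction=16):
        super(CALayer, self).__init__()
        self.body = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),             # (B, nf, 1, 1) channel descriptor
            nn.Conv2d(nf, nf // reduction, 1),   # squeeze
            nn.ReLU(inplace=True),
            nn.Conv2d(nf // reduction, nf, 1),   # excite
            nn.Sigmoid())

    def forward(self, x):
        return x * self.body(x)   # per-channel rescaling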

#################################################################################################################################

RRDBNet

class RRDB(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)
        self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)
        self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out.mul(0.2) + x   # ESRGAN-style residual scaling (x0.2)


class ResidualDenseBlock_5C(nn.Module):
    '''
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=last_act, mode=mode)

    def forward(self, x):
        # dense connectivity: each conv receives all preceding feature maps
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5.mul(0.2) + x   # local residual scaling
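These two classes assume the codebase's conv_block helper (padding, conv, norm, and activation assembled per mode). A simplified stand-in, enough to run the snippet on its own; the real helper also supports reflection padding, the NAC ordering, and more norm/activation types:

import torch.nn as nn

def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
               pad_type='zero', norm_type=None, act_type='leakyrelu', mode='CNA'):
    # Simplified sketch: zero padding and CNA (conv -> norm -> act) ordering only.
    padding = (kernel_size - 1) // 2 * dilation
    layers = [nn.Conv2d(in_nc, out_nc, kernel_size, stride, padding, dilation, groups, bias)]
    if norm_type == 'batch':
        layers.append(nn.BatchNorm2d(out_nc))
    if act_type == 'relu':
        layers.append(nn.ReLU(inplace=True))
    elif act_type == 'leakyrelu':
        layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)

With this stand-in (and torch imported), RRDB(nc=64) maps a (1, 64, 32, 32) tensor to a tensor of the same shape.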

 

##########################################################################################################
class channel_attention_OctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA', reduction=8):
        super(channel_attention_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0

        self.stride = stride
        self.out = out_nc
        self.alpha = alpha

        # intra-frequency convolutions only; the branches interact through the attention below
        self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None

        self.average = nn.AdaptiveAvgPool2d(1)

        # squeeze-and-excitation over the concatenated (pooled) HR/LR descriptors
        self.attention = sequential(
                conv_block(in_nc, in_nc // reduction, 1, stride, dilation, groups, bias, pad_type, \
                            norm_type, act_type, mode),
                conv_block(in_nc // reduction, out_nc, 1, stride, dilation, groups, bias, pad_type, \
                            norm_type, None, mode),
                nn.Sigmoid())

    def forward(self, x):
        X_h, X_l = x

        X_h2h = self.h2h(X_h)
        X_l2l = self.l2l(X_l)

        # global-pool both branches, predict one attention vector, then split it
        X_hpooling = self.average(X_h2h)
        X_lpooling = self.average(X_l2l)
        X_attention = torch.cat((X_hpooling, X_lpooling), dim=1)
        X_attention = self.attention(X_attention)
        X_h_attention = X_attention[:, :self.out - int(self.alpha * self.out), :, :]
        X_l_attention = X_attention[:, self.out - int(self.alpha * self.out):, :, :]

        X_h = X_h2h * X_h_attention
        X_l = X_l2l * X_l_attention

        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)

        if self.a:
            X_h = self.a(X_h)
            X_l = self.a(X_l)

        return X_h, X_l
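Expected input/output contract, a sketch assuming alpha=0.5, in_nc=out_nc=64, and the act/norm/conv_block/sequential helpers importable from block.py:

m = channel_attention_OctaveConv(64, 64, kernel_size=3, alpha=0.5)
X_h = torch.randn(1, 32, 48, 48)   # high-frequency branch, full resolution
X_l = torch.randn(1, 32, 24, 24)   # low-frequency branch, half resolution
Y_h, Y_l = m((X_h, X_l))           # per-branch shapes are preserved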
################################################################################################

CARN

#################################################################################################################################
class CARN(nn.Module):
    def __init__(self, in_nc=3, out_nc=3, nf=64, nc=4, nb=3, upscale=2, norm_type=None, \
        act_type='prelu', mode='NAC', res_scale=1.0, upsample_mode='upconv'):
        super(CARN, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        self.nb = nb   # number of cascade blocks

        # alpha = 0.5
        self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type='relu', mode='CNA')
        self.fea_conv1 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type='relu', mode='CNA')
        # octave variant:
        # self.fea_conv = B.FirstOctaveConv(in_nc, nf, kernel_size=3, alpha=alpha, norm_type=None, act_type='relu', mode='CNA')

        self.CascadeBlocks = nn.ModuleList([B.CascadeBlock(nc, nf, kernel_size=3, norm_type=norm_type, \
            act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)])
        # octave variant:
        # self.CascadeBlocks = nn.ModuleList([B.OctaveCascadeBlock(nc, nf, kernel_size=3, alpha=alpha, \
        #     norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)])
        # octave variant without HR<->LR connection:
        # self.CascadeBlocks = nn.ModuleList([B.withoutconnection_OctaveCascadeBlock(nc, nf, kernel_size=3, alpha=alpha, \
        #     norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)])

        # 1x1 convs that fuse the growing concatenation of block outputs back to nf channels
        self.CatBlocks = nn.ModuleList([B.conv_block((i + 2)*nf, nf, kernel_size=1, \
            norm_type=norm_type, act_type=act_type, mode=mode) for i in range(nb)])
        # octave variant:
        # self.CatBlocks = nn.ModuleList([B.OctaveConv((i + 2)*nf, nf, kernel_size=1, alpha=alpha, \
        #     norm_type=norm_type, act_type=act_type, mode=mode) for i in range(nb)])
        # octave variant without connection:
        # self.CatBlocks = nn.ModuleList([B.withoutconnection_OctaveConv((i + 2)*nf, nf, kernel_size=1, alpha=alpha, \
        #     norm_type=norm_type, act_type=act_type, mode=mode) for i in range(nb)])

        if upsample_mode == 'upconv':
            # use the plain block to match the plain (non-octave) trunk above;
            # swap in B.octave_upconv_blcok together with the octave cascade blocks
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
            # octave variant: upsample_block = B.octave_pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        # octave variant: HR_conv0 = B.LastOctaveConv(nf, nf, kernel_size=3, alpha=alpha, norm_type=None, act_type=act_type)

        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.subpixel_up = B.sequential(*upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.fea_conv(x)
        x = self.fea_conv1(x)

        # cascading: each block output is concatenated with all previous features,
        # then fused back to nf channels by a 1x1 conv
        pre_fea = x
        for i in range(self.nb):
            res = self.CascadeBlocks[i](x)
            pre_fea = torch.cat((pre_fea, res), dim=1)
            x = self.CatBlocks[i](pre_fea)

        # octave variant (features are (HR, LR) tuples):
        # pre_fea = x
        # for i in range(self.nb):
        #     res = self.CascadeBlocks[i](x)
        #     pre_fea = (torch.cat((pre_fea[0], res[0]), dim=1), \
        #                 torch.cat((pre_fea[1], res[1]), dim=1))
        #     x = self.CatBlocks[i](pre_fea)

        x = self.subpixel_up(x)
        return x

####################################################################################################
class CascadeBlock(nn.Module):
    """
    CascadeBlock, 3-3 style
    """

    def __init__(self, nc, gc, kernel_size=3, stride=1, dilation=1, groups=1, \
            bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA', res_scale=1):
        super(CascadeBlock, self).__init__()
        self.nc = nc
        self.ResBlocks = nn.ModuleList([ResNetBlock(gc, gc, gc, kernel_size, stride, dilation, groups, bias, \
            pad_type, norm_type, act_type, mode, res_scale) for _ in range(nc)])
        self.CatBlocks = nn.ModuleList([conv_block((i + 2)*gc, gc, kernel_size=1, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode) for i in range(nc)])

    def forward(self, x):
        # same cascading pattern as the network level, applied across residual blocks
        pre_fea = x
        for i in range(self.nc):
            res = self.ResBlocks[i](x)
            pre_fea = torch.cat((pre_fea, res), dim=1)
            x = self.CatBlocks[i](pre_fea)
        return x
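A quick smoke test of the cascade wiring (a sketch; it assumes the B helpers above are importable and uses the plain, non-octave configuration):

model = CARN(in_nc=3, out_nc=3, nf=64, nc=4, nb=3, upscale=2)
x = torch.randn(1, 3, 24, 24)
print(model(x).shape)   # expected: torch.Size([1, 3, 48, 48])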

 

 

 
