Experiment Notes: Octave Layer (4-branch data)

This post records an experiment that modifies the octave layer to use four data branches.

Run the experiment: python train.py -opt options/train/train_sr.json

First activate the virtual environment: source activate pytorch

tensorboard --logdir tb_logger/ --port 6008

Open http://172.20.36.203:6008/#scalars in the browser to monitor training.
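For reference, the curves shown under #scalars come from event files that the training script writes into tb_logger/. Below is a minimal sketch of how such scalars are produced with torch.utils.tensorboard; the log path and tag name are illustrative assumptions, not the repository's own logging code.

from torch.utils.tensorboard import SummaryWriter

# hypothetical run directory and tag; the real train.py manages its own logger
writer = SummaryWriter(log_dir='tb_logger/example_run')
for step in range(100):
    writer.add_scalar('train/l1_loss', 1.0 / (step + 1), step)
writer.close()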
 

The code is given first.

Original code

##################################################################################
##################################################################################
#modified octave
# Block for OctConv
####################
class M_NP_OctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(M_NP_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.h2g_pool2 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        #self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.upsample = nn.Upsample(scale_factor=4, mode='nearest')#double pool
        self.stride = stride
 
        self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc - int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_nc - int(alpha * in_nc), int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None
 
    def forward(self, x):
        X_h, X_l = x
 
        #if self.stride ==2:
            #X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
 
        X_h2h = self.h2h(X_h)
        #X_l2h = self.upsample(self.l2h(X_l))
        #X_l2h = self.l2h(X_l)
        X_l2h = self.upsample(self.l2h(X_l))
 
        X_l2l = self.l2l(X_l)
        #X_h2l = self.h2l(self.h2g_pool(X_h))
        #X_h2l = self.h2l(X_h)
        X_h2l = self.h2l(self.h2g_pool2(self.h2g_pool(X_h)))
        
        #print(X_l2h.shape,"~~~~",X_h2h.shape)
        X_h = X_l2h + X_h2h
        X_l = X_h2l + X_l2l
 
        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)
 
        if self.a:
            X_h = self.a(X_h)
            X_l = self.a(X_l)
 
        return X_h, X_l
 
 
class M_NP_FirstOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(M_NP_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.h2g_pool2 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.stride = stride
        ###low frequency
        self.h2l = nn.Conv2d(in_nc, int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        ###high frequency
        self.h2h = nn.Conv2d(in_nc, out_nc - int(alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None
 
    def forward(self, x):
        #if self.stride ==2:
            #x = self.h2g_pool(x)
 
        X_h = self.h2h(x)
        #X_l = self.h2l(self.h2g_pool(x))
        #X_l = self.h2l(x)#without pool
        X_l = self.h2l(self.h2g_pool2(self.h2g_pool(x)))#double pool
 
        if self.n_h and self.n_l:##batch norm
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)
 
        if self.a:#Activation layer
            X_h = self.a(X_h)
            X_l = self.a(X_l)
 
        return X_h, X_l
 
 
class M_NP_LastOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(M_NP_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2,2), stride=2)
        self.h2g_pool2 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        #self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.upsample = nn.Upsample(scale_factor=4, mode='nearest')##double pool
        self.stride = stride
 
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
                                kernel_size, 1, padding, dilation, groups, bias)
 
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None
 
    def forward(self, x):
        X_h, X_l = x
 
        #if self.stride ==2:
            #X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        
        X_h2h = self.h2h(X_h)
        X_l2h = self.upsample(self.l2h(X_l))
        #X_l2h = self.l2h(X_l)
        
        X_h = X_h2h + X_l2h
 
        if self.n_h:
            X_h = self.n_h(X_h)
 
        if self.a:
            X_h = self.a(X_h)
 
        return X_h
 
class M_NP_octave_ResidualDenseBlockTiny_4C(nn.Module):
    '''
    Residual Dense Block
    style: 4 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    '''
 
    def __init__(self, nc, kernel_size=3, gc=16,alpha=0.5, stride=1, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(M_NP_octave_ResidualDenseBlockTiny_4C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 =M_NP_OctaveConv(nc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
             norm_type=norm_type, act_type=act_type, mode=mode) 
        # conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
        #     norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = M_NP_OctaveConv(nc+gc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
             norm_type=norm_type, act_type=act_type, mode=mode) 
        # conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
        #     norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = M_NP_OctaveConv(nc+2*gc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
             norm_type=norm_type, act_type=act_type, mode=mode) 
        # conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
        #     norm_type=norm_type, act_type=act_type, mode=mode)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv4 = M_NP_OctaveConv(nc+3*gc, nc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
             norm_type=norm_type, act_type=last_act, mode=mode)
        # conv_block(nc+3*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
        #     norm_type=norm_type, act_type=last_act, mode=mode)
 
    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2((torch.cat((x[0], x1[0]), dim=1),(torch.cat((x[1], x1[1]), dim=1))))
        x3 = self.conv3((torch.cat((x[0], x1[0],x2[0]), dim=1),(torch.cat((x[1], x1[1],x2[1]), dim=1))))
        x4 = self.conv4((torch.cat((x[0], x1[0],x2[0],x3[0]), dim=1),(torch.cat((x[1], x1[1],x2[1],x3[1]), dim=1))))
 
        res = (x4[0].mul(0.2), x4[1].mul(0.2))
        x = (x[0] + res[0], x[1] + res[1])
        #print(len(x),"~~~",len(res),"~~~",len(x + res))
 
        #return (x[0] + res[0], x[1]+res[1])
        return x
 
 
class M_NP_octave_RRDBTiny(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''
 
    def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(M_NP_octave_RRDBTiny, self).__init__()
        self.RDB1 = M_NP_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size,alpha=alpha, gc=gc, stride=stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.RDB2 = M_NP_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size,alpha=alpha, gc=gc, stride=stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
 
    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
 
        res = (out[0].mul(0.2), out[1].mul(0.2))
        x = (x[0] + res[0], x[1] + res[1])
        #print(len(x),"~~~",len(res),"~~~",len(x + res))
 
        #return (x[0] + res[0], x[1]+res[1])
        return x

 
class M_NP_OctaveCascadeBlock(nn.Module):
    """
    OctaveCascadeBlock, 3-3 style
    """
    def __init__(self, nc, gc, kernel_size=3, alpha=0.75, stride=1, dilation=1, groups=1, \
            bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA', res_scale=1):
        super(M_NP_OctaveCascadeBlock, self).__init__()
        self.nc = nc
        self.ResBlocks = nn.ModuleList([M_NP_OctaveResBlock(gc, gc, gc, kernel_size, alpha, stride, dilation, \
            groups, bias, pad_type, norm_type, act_type, mode, res_scale) for _ in range(nc)])
        self.CatBlocks = nn.ModuleList([M_NP_OctaveConv((i + 2)*gc, gc, kernel_size=1, alpha=alpha, bias=bias, \
            pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode) for i in range(nc)])

    def forward(self, x):
        pre_fea = x
        for i in range(self.nc):
            res = self.ResBlocks[i](x)
            pre_fea = (torch.cat((pre_fea[0], res[0]), dim=1), \
                        torch.cat((pre_fea[1], res[1]), dim=1))
            x = self.CatBlocks[i](pre_fea)
        return x

class M_NP_OctaveResBlock(nn.Module):
    '''
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
    '''
    def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, alpha=0.75, stride=1, dilation=1, groups=1, \
            bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA', res_scale=1):
        super(M_NP_OctaveResBlock, self).__init__()
        conv0 = M_NP_OctaveConv(in_nc, mid_nc, kernel_size, alpha, stride, dilation, groups, bias, pad_type, \
            norm_type, act_type, mode)
        if mode == 'CNA':
            act_type = None
        if mode == 'CNAC':  # Residual path: |-CNAC-|
            act_type = None
            norm_type = None
        conv1 = M_NP_OctaveConv(mid_nc, out_nc, kernel_size, alpha, stride, dilation, groups, bias, pad_type, \
            norm_type, act_type, mode)

        self.res = sequential(conv0, conv1)
        self.res_scale = res_scale

    def forward(self, x):
        #if(len(x)>2):
            #print(x[0].shape,"  ",x[1].shape,"  ",x[2].shape,"  ",x[3].shape)
        #print(len(x))
        res = self.res(x)
        res = (res[0].mul(self.res_scale), res[1].mul(self.res_scale))
        x = (x[0] + res[0], x[1] + res[1])
        #print(len(x),"~~~",len(res),"~~~",len(x + res))

        #return (x[0] + res[0], x[1]+res[1])
        return x
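In the original version above, the feature map is split into two branches: X_h at full resolution and, because of the "double pool" plus scale_factor=4 upsampling, X_l at 1/4 resolution. The following self-contained sketch (my own simplification for shape checking: the norm/activation hooks are dropped and alpha is fixed at 0.5; it is not the repository code) can be used to verify the branch shapes:

import torch
import torch.nn as nn

class TwoBranchOctaveSketch(nn.Module):
    # minimal 2-branch octave conv: high path at full size, low path at 1/4 size
    def __init__(self, in_nc=64, out_nc=64, alpha=0.5, kernel_size=3):
        super(TwoBranchOctaveSketch, self).__init__()
        pad = kernel_size // 2
        l_in, l_out = int(alpha * in_nc), int(alpha * out_nc)
        h_in, h_out = in_nc - l_in, out_nc - l_out
        self.pool = nn.AvgPool2d(2, 2)                          # applied twice -> 1/4 size
        self.up = nn.Upsample(scale_factor=4, mode='nearest')   # back to full size
        self.h2h = nn.Conv2d(h_in, h_out, kernel_size, 1, pad)
        self.h2l = nn.Conv2d(h_in, l_out, kernel_size, 1, pad)
        self.l2h = nn.Conv2d(l_in, h_out, kernel_size, 1, pad)
        self.l2l = nn.Conv2d(l_in, l_out, kernel_size, 1, pad)

    def forward(self, x):
        X_h, X_l = x
        out_h = self.h2h(X_h) + self.up(self.l2h(X_l))
        out_l = self.l2l(X_l) + self.h2l(self.pool(self.pool(X_h)))
        return out_h, out_l

# shape check for a 64x64 patch: high branch 32x64x64, low branch 32x16x16
m = TwoBranchOctaveSketch()
out_h, out_l = m((torch.randn(1, 32, 64, 64), torch.randn(1, 32, 16, 16)))
print(out_h.shape, out_l.shape)  # [1, 32, 64, 64]  [1, 32, 16, 16]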

Modified code

##################################################################################
##################################################################################
##################################################################################
#modified octave
# Block for OctConv
####################
class M_NP_OctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(M_NP_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.h2g_pool2 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.h2g_pool3 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        #self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')## pool
        self.upsample2 = nn.Upsample(scale_factor=4, mode='nearest')##double pool
        self.upsample3 = nn.Upsample(scale_factor=8, mode='nearest')##triple pool
        self.stride = stride

        # self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
        #                         kernel_size, 1, padding, dilation, groups, bias)
        # self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc - int(alpha * out_nc),
        #                         kernel_size, 1, padding, dilation, groups, bias)
        # self.h2l = nn.Conv2d(in_nc - int(alpha * in_nc), int(alpha * out_nc),
        #                         kernel_size, 1, padding, dilation, groups, bias)
        # self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
        #                         kernel_size, 1, padding, dilation, groups, bias)
#########for h
        self.h2h = nn.Conv2d(int(0.5*(in_nc - int(alpha * in_nc))), int(0.5*(out_nc - int(alpha * out_nc))),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l12h = nn.Conv2d(int(0.5*(in_nc - int(alpha * in_nc))), int(0.5*(out_nc - int(alpha * out_nc))),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l22h = nn.Conv2d(int(0.5*int(alpha * in_nc)), int(0.5*(out_nc - int(alpha * out_nc))),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l32h = nn.Conv2d(int(0.5*int(alpha * in_nc)), int(0.5*(out_nc - int(alpha * out_nc))),
                                kernel_size, 1, padding, dilation, groups, bias)
###########for l1
        self.h2l1 = nn.Conv2d(int(0.5*(in_nc - int(alpha * in_nc))), int(0.5*(out_nc - int(alpha * out_nc))),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l12l1 = nn.Conv2d(int(0.5*(in_nc - int(alpha * in_nc))), int(0.5*(out_nc - int(alpha * out_nc))),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l22l1 = nn.Conv2d(int(0.5*int(alpha * in_nc)), int(0.5*(out_nc - int(alpha * out_nc))),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l32l1 = nn.Conv2d(int(0.5*int(alpha * in_nc)), int(0.5*(out_nc - int(alpha * out_nc))),
                                kernel_size, 1, padding, dilation, groups, bias)
###########for l2
        self.h2l2 = nn.Conv2d(int(0.5*(in_nc - int(alpha * in_nc))), int(0.5*int(alpha * out_nc)),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l12l2 = nn.Conv2d(int(0.5*(in_nc - int(alpha * in_nc))), int(0.5*int(alpha * out_nc)),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l22l2 = nn.Conv2d(int(0.5*int(alpha * in_nc)), int(0.5*int(alpha * out_nc)),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l32l2 = nn.Conv2d(int(0.5*int(alpha * in_nc)), int(0.5*int(alpha * out_nc)),
                                kernel_size, 1, padding, dilation, groups, bias)
###########for l3
        self.h2l3 = nn.Conv2d(int(0.5*(in_nc - int(alpha * in_nc))), int(0.5*int(alpha * out_nc)),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l12l3 = nn.Conv2d(int(0.5*(in_nc - int(alpha * in_nc))), int(0.5*int(alpha * out_nc)),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l22l3 = nn.Conv2d(int(0.5*int(alpha * in_nc)), int(0.5*int(alpha * out_nc)),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l32l3 = nn.Conv2d(int(0.5*int(alpha * in_nc)), int(0.5*int(alpha * out_nc)),
                                kernel_size, 1, padding, dilation, groups, bias)



        self.a = act(act_type) if act_type else None
        # note: these norm widths follow the original 2-branch split and do not match the
        # per-branch widths above, so norm_type should stay None in this 4-branch variant
        self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None
 
    def forward(self, x):
        X_h, X_l1,X_l2,X_l3 = x
 
        #if self.stride ==2:
            #X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
 
 #######for h
        X_h2h = self.h2h(X_h)
        X_l12h=self.l12h(self.upsample(X_l1))
        X_l22h=self.l22h(self.upsample2(X_l2))
        X_l32h=self.l32h(self.upsample3(X_l3))

 #######for l1
        X_h2l1 = self.h2l1(self.h2g_pool(X_h))
        X_l12l1=self.l12l1(X_l1)
        X_l22l1=self.l22l1(self.upsample(X_l2))
        X_l32l1=self.l32l1(self.upsample2(X_l3))

#######for l2
        X_h2l2 = self.h2l2(self.h2g_pool2(self.h2g_pool(X_h)))
        X_l12l2=self.l12l2(self.h2g_pool(X_l1))
        X_l22l2=self.l22l2(X_l2)
        X_l32l2=self.l32l2(self.upsample(X_l3))

#######for l3
        X_h2l3 = self.h2l3(self.h2g_pool3(self.h2g_pool2(self.h2g_pool(X_h))))
        X_l12l3=self.l12l3(self.h2g_pool2(self.h2g_pool(X_l1)))
        X_l22l3=self.l22l3(self.h2g_pool(X_l2))
        X_l32l3=self.l32l3(X_l3)


        
        #print(X_l2h.shape,"~~~~",X_h2h.shape)
        X_h = X_h2h + X_l12h+X_l22h+X_l32h
        X_l1 = X_h2l1+X_l12l1+X_l22l1+X_l32l1
        X_l2=X_h2l2+X_l12l2+X_l22l2+X_l32l2
        X_l3=X_h2l3+X_l12l3+X_l22l3+X_l32l3
 
        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l1 = self.n_l(X_l1)
            X_l2 = self.n_l(X_l2)
            X_l3 = self.n_l(X_l3)
 
        if self.a:
            X_h = self.a(X_h)
            X_l1 = self.a(X_l1)
            X_l2 = self.a(X_l2)
            X_l3 = self.a(X_l3)
 
        return X_h, X_l1,X_l2,X_l3
 
 
class M_NP_FirstOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(M_NP_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.h2g_pool2 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.h2g_pool3 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.stride = stride
        ###low frequency
        self.h2l2 = nn.Conv2d(in_nc, int(0.5*alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2l3 = nn.Conv2d(in_nc, int(0.5*alpha * out_nc),
                                kernel_size, 1, padding, dilation, groups, bias)
        ###high frequency
        self.h2h = nn.Conv2d(in_nc, int(0.5*out_nc - int(0.5*alpha * out_nc)),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2l1 = nn.Conv2d(in_nc, int(0.5*out_nc - int(0.5*alpha * out_nc)),
                                kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None
 
    def forward(self, x):
        #if self.stride ==2:
            #x = self.h2g_pool(x)
 
        X_h = self.h2h(x)
        X_l1=self.h2l1(self.h2g_pool(x))
        X_l2=self.h2l2(self.h2g_pool2(self.h2g_pool(x)))
        X_l3=self.h2l3(self.h2g_pool3(self.h2g_pool2(self.h2g_pool(x))))
        #X_l = self.h2l(self.h2g_pool2(self.h2g_pool(x)))#double pool
 
        if self.n_h and self.n_l:##batch norm
            X_h = self.n_h(X_h)
            X_l1 = self.n_l(X_l1)
            X_l2 = self.n_l(X_l2)
            X_l3 = self.n_l(X_l3)
 
        if self.a:#Activation layer
            X_h = self.a(X_h)
            X_l1 = self.a(X_l1)
            X_l2 = self.a(X_l2)
            X_l3 = self.a(X_l3)
 
        return X_h, X_l1,X_l2,X_l3
 
 
class M_NP_LastOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(M_NP_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        # self.h2g_pool = nn.AvgPool2d(kernel_size=(2,2), stride=2)
        # self.h2g_pool2 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        #self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')## pool
        self.upsample2 = nn.Upsample(scale_factor=4, mode='nearest')##double pool
        self.upsample3 = nn.Upsample(scale_factor=8, mode='nearest')##triple pool
        self.stride = stride
 
        self.l22h = nn.Conv2d(int(0.5*int(alpha * in_nc)), out_nc,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l32h = nn.Conv2d(int(0.5*int(alpha * in_nc)), out_nc,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(int(0.5*(in_nc - int(alpha * in_nc))), out_nc,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l12h = nn.Conv2d(int(0.5*(in_nc - int(alpha * in_nc))), out_nc,
                                kernel_size, 1, padding, dilation, groups, bias)
 
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None
 
    def forward(self, x):
        X_h, X_l1,X_l2,X_l3 = x
 
        #if self.stride ==2:
            #X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        
        X_h2h = self.h2h(X_h)
        X_l12h = self.upsample(self.l12h(X_l1))
        X_l22h = self.upsample2(self.l22h(X_l2))
        X_l32h = self.upsample3(self.l32h(X_l3))
        #X_l2h = self.l2h(X_l)
        
        X_h = X_h2h + X_l12h+X_l22h+X_l32h
 
        if self.n_h:
            X_h = self.n_h(X_h)
 
        if self.a:
            X_h = self.a(X_h)
 
        return X_h


class M_NP_OctaveResBlock(nn.Module):
    '''
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
    '''
    def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, alpha=0.75, stride=1, dilation=1, groups=1, \
            bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA', res_scale=1):
        super(M_NP_OctaveResBlock, self).__init__()
        conv0 = M_NP_OctaveConv(in_nc, mid_nc, kernel_size, alpha, stride, dilation, groups, bias, pad_type, \
            norm_type, act_type, mode)
        if mode == 'CNA':
            act_type = None
        if mode == 'CNAC':  # Residual path: |-CNAC-|
            act_type = None
            norm_type = None
        conv1 = M_NP_OctaveConv(mid_nc, out_nc, kernel_size, alpha, stride, dilation, groups, bias, pad_type, \
            norm_type, act_type, mode)

        self.res = sequential(conv0, conv1)
        self.res_scale = res_scale

    def forward(self, x):
        #if(len(x)>2):
            #print(x[0].shape,"  ",x[1].shape,"  ",x[2].shape,"  ",x[3].shape)
        #print(len(x))
        res = self.res(x)
        res = (res[0].mul(self.res_scale), res[1].mul(self.res_scale),res[2].mul(self.res_scale),res[3].mul(self.res_scale))
        x = (x[0] + res[0], x[1] + res[1],x[2] + res[2],x[3] + res[3])
        #print(len(x),"~~~",len(res),"~~~",len(x + res))

        #return (x[0] + res[0], x[1]+res[1])
        return x
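In the modified version the feature map is split into four branches: X_h at full resolution, X_l1 at 1/2, X_l2 at 1/4 and X_l3 at 1/8, and every branch is convolved into every other branch (16 cross paths in total), with AvgPool2d and Upsample bridging the resolution gaps. A quick sanity check of the channel split used by M_NP_FirstOctaveConv, written as a small standalone helper (the function name is my own, not part of the code above):

def four_branch_channels(out_nc, alpha=0.5):
    # h and l1 share the "high" half of the channels, l2 and l3 share the "low" half,
    # mirroring the int(0.5*alpha*out_nc) / int(0.5*out_nc - ...) expressions above
    c_low = int(0.5 * alpha * out_nc)        # channels of l2 and of l3
    c_high = int(0.5 * out_nc - c_low)       # channels of h and of l1
    return c_high, c_high, c_low, c_low

for nc in (64, 32):
    ch = four_branch_channels(nc, alpha=0.5)
    print(nc, ch, sum(ch))   # 64 (16, 16, 16, 16) 64  /  32 (8, 8, 8, 8) 32

# Spatial sizes of the four branches for a 64x64 input patch:
# X_h 64x64, X_l1 32x32 (one AvgPool), X_l2 16x16 (two), X_l3 8x8 (three)

With alpha=0.5 the four branches end up with equal width, so the total parameter count of the 16 cross convolutions is roughly the same as that of the 4 convolutions in the 2-branch version, while the features now mix information across four scales instead of two.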

 

 

 

 

 

 

 

 

 
