实验笔记之——基于DWT的octave layer(DWT在pytorch中实现)

之前的博文《论文阅读笔记之——《Multi-level Wavelet-CNN for Image Restoration》及基于pytorch的复现》曾经研究过MWCNN。本博文则采用DWT变换来代替octave卷积中的pooling操作。

代码

与博文《实验笔记之——octave conv (without pooling)》类似,对octave layer的结构进行如下改进:

pytorch中实现离散小波变换

https://github.com/fbcotter/pytorch_wavelets

git clone https://github.com/fbcotter/pytorch_wavelets

cd pytorch_wavelets
pip install .
pip install -r tests/requirements.txt

测试

修改后的代码如下:

##################################################################################
##################################################################################
##################################################################################
#DWT octave
# Block for OctConv
####################
class DWT_OctaveConv(nn.Module):
    """Octave convolution whose low-frequency path uses a DTCWT (variant 1).

    ``forward`` takes a ``(X_h, X_l)`` pair and returns a new ``(X_h, X_l)``
    pair. All four convolutions use stride 1 and there is no pooling, so both
    branches keep their input spatial size.

    Before ``l2h`` and ``l2l``, the low-frequency input goes through a
    ``DTCWTForward`` / ``DTCWTInverse`` round trip. The new low-frequency
    output is built in the transform domain: the coefficients of
    ``h2l(X_h)`` and ``l2l(...)`` are added level by level, and the sum is
    inverted.

    Args:
        in_nc, out_nc: total input / output channels. ``int(alpha * nc)`` of
            them belong to the low-frequency branch, the rest to the high one.
        kernel_size, dilation, groups, bias: passed to every ``nn.Conv2d``.
        alpha: fraction of channels assigned to the low-frequency branch.
        stride: stored on the module but not used by ``forward``.
        pad_type: ``'zero'`` pads with ``get_valid_padding``; otherwise 0.
        norm_type, act_type: optional normalization / activation applied to
            both outputs (built with the project's ``norm`` / ``act``).
        mode: only validated; must be ``'CNA'``, ``'NAC'`` or ``'CNAC'``.
    """

    @staticmethod
    def _split_channels(nc, alpha):
        """Return ``(low, high)`` channel counts for ``nc`` channels."""
        low = int(alpha * nc)
        return low, nc - low

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        # No hard-coded .cuda(). The transforms are submodules, so a parent
        # .to()/.cuda() moves their registered filters together with the
        # convolutions, and the layer also runs on CPU.
        self.xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b')
        self.ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b')
        self.stride = stride

        in_l, in_h = self._split_channels(in_nc, alpha)
        out_l, out_h = self._split_channels(out_nc, alpha)

        self.l2l = nn.Conv2d(in_l, out_l, kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(in_l, out_h, kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_h, out_l, kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_h, out_h, kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        # The norms are sized from the same split as the convs. The old
        # int(out_nc * (1 - alpha)) can differ from out_nc - int(alpha * out_nc);
        # for example out_nc=10, alpha=0.25 gives 7 vs 8.
        self.n_h = norm(norm_type, out_h) if norm_type else None
        self.n_l = norm(norm_type, out_l) if norm_type else None

    def forward(self, x):
        X_h, X_l = x

        # DTCWT analysis/synthesis round trip of the low-frequency input.
        l_low, l_highs = self.xfm(X_l)
        X_h2h = self.h2h(X_h)
        X_l_rec = self.ifm((l_low, l_highs))

        X_l2h = self.l2h(X_l_rec)
        X_l2l = self.l2l(X_l_rec)
        X_h2l = self.h2l(X_h)

        h2l_low, h2l_highs = self.xfm(X_h2l)
        l2l_low, l2l_highs = self.xfm(X_l2l)

        X_h = X_l2h + X_h2h
        # Combine both contributions in the wavelet domain and invert. A fresh
        # list is built instead of writing into the container returned by xfm,
        # which also works if the transform returns a tuple.
        low = h2l_low + l2l_low
        highs = [l2l_hi + h2l_hi for l2l_hi, h2l_hi in zip(l2l_highs, h2l_highs)]
        X_l = self.ifm((low, highs))

        if self.n_h is not None and self.n_l is not None:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)

        if self.a is not None:
            # The same activation module (shared weights if PReLU) is used
            # for both branches.
            X_h = self.a(X_h)
            X_l = self.a(X_l)

        return X_h, X_l


class DWT_FirstOctaveConv(nn.Module):
    """First octave layer: split one feature map into an ``(X_h, X_l)`` pair.

    Two stride-1 convolutions read the same input. ``h2h`` produces the
    high-frequency branch and ``h2l`` produces the low-frequency branch; no
    pooling is applied, so both keep the input's spatial size. Optional norm
    and activation are applied afterwards.

    Args:
        in_nc: input channels (both convs read all of them).
        out_nc: total output channels; ``int(alpha * out_nc)`` go to the
            low-frequency branch, the rest to the high-frequency branch.
        kernel_size, dilation, groups, bias: passed to both ``nn.Conv2d``.
        alpha: fraction of output channels assigned to the low branch.
        stride: stored on the module but not used by ``forward``.
        pad_type: ``'zero'`` pads with ``get_valid_padding``; otherwise 0.
        norm_type, act_type: optional normalization / activation.
        mode: only validated; must be ``'CNA'``, ``'NAC'`` or ``'CNAC'``.
    """

    @staticmethod
    def _split_channels(nc, alpha):
        """Return ``(low, high)`` channel counts for ``nc`` channels."""
        low = int(alpha * nc)
        return low, nc - low

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.stride = stride
        out_l, out_h = self._split_channels(out_nc, alpha)
        # Low-frequency branch.
        self.h2l = nn.Conv2d(in_nc, out_l, kernel_size, 1, padding, dilation, groups, bias)
        # High-frequency branch.
        self.h2h = nn.Conv2d(in_nc, out_h, kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        # The norms are sized exactly like the conv outputs. The previous
        # int(out_nc * (1 - alpha)) could disagree with out_nc - int(alpha * out_nc).
        self.n_h = norm(norm_type, out_h) if norm_type else None
        self.n_l = norm(norm_type, out_l) if norm_type else None

    def forward(self, x):
        X_h = self.h2h(x)
        X_l = self.h2l(x)

        if self.n_h is not None and self.n_l is not None:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)

        if self.a is not None:
            # The same activation module is shared by both branches.
            X_h = self.a(X_h)
            X_l = self.a(X_l)

        return X_h, X_l


class DWT_LastOctaveConv(nn.Module):
    """Final octave layer: fuse an ``(X_h, X_l)`` pair into one feature map.

    Each branch is convolved to ``out_nc`` channels (stride 1) and the two
    results are summed, so both branches must have the same spatial size.
    Optional normalization and activation follow the sum.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        pad = 0
        if pad_type == 'zero':
            pad = get_valid_padding(kernel_size, dilation)
        self.stride = stride

        # Channel split of the incoming pair: low branch first, high gets the rest.
        low_in = int(alpha * in_nc)
        high_in = in_nc - low_in
        self.l2h = nn.Conv2d(low_in, out_nc, kernel_size, 1, pad, dilation, groups, bias)
        self.h2h = nn.Conv2d(high_in, out_nc, kernel_size, 1, pad, dilation, groups, bias)

        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        high, low = x
        fused = self.h2h(high) + self.l2h(low)
        if self.n_h:
            fused = self.n_h(fused)
        if self.a:
            fused = self.a(fused)
        return fused

class DWT_octave_ResidualDenseBlockTiny_4C(nn.Module):
    '''
    Residual Dense Block built from DWT_OctaveConv layers, operating on
    octave (X_h, X_l) pairs.
    style: 4 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)

    Each conv receives the per-branch channel concatenation of the block
    input and all previous conv outputs. The output of the last conv is
    scaled by 0.2 and added back to the input, branch by branch.
    '''

    def __init__(self, nc, kernel_size=3, gc=16, alpha=0.5, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_ResidualDenseBlockTiny_4C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        common = dict(bias=bias, pad_type=pad_type, norm_type=norm_type, mode=mode)
        self.conv1 = DWT_OctaveConv(nc, gc, kernel_size, alpha, stride, act_type=act_type, **common)
        self.conv2 = DWT_OctaveConv(nc + gc, gc, kernel_size, alpha, stride, act_type=act_type, **common)
        self.conv3 = DWT_OctaveConv(nc + 2 * gc, gc, kernel_size, alpha, stride, act_type=act_type, **common)
        # In 'CNA' mode the last conv has no activation, so the residual branch
        # is not activated before being added back. This matches the conv_block
        # reference (act_type=last_act). Previously last_act was computed but
        # act_type was passed instead.
        last_act = None if mode == 'CNA' else act_type
        self.conv4 = DWT_OctaveConv(nc + 3 * gc, nc, kernel_size, alpha, stride, act_type=last_act, **common)

    @staticmethod
    def _cat(*feats):
        """Concatenate several branch tuples along channels, branch by branch."""
        return tuple(torch.cat(parts, dim=1) for parts in zip(*feats))

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(self._cat(x, x1))
        x3 = self.conv3(self._cat(x, x1, x2))
        x4 = self.conv4(self._cat(x, x1, x2, x3))
        # Residual scaling by 0.2 as in ESRGAN, applied to each branch.
        return tuple(inp + out.mul(0.2) for inp, out in zip(x, x4))


class DWT_octave_RRDBTiny(nn.Module):
    '''
    Residual in Residual Dense Block on octave (X_h, X_l) pairs
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks).

    Two DWT_octave_ResidualDenseBlockTiny_4C blocks run in sequence. Their
    output is scaled by 0.2 and added to the block input, branch by branch.
    '''

    def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_RRDBTiny, self).__init__()
        rdb_args = dict(nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc, stride=stride,
                        bias=bias, pad_type=pad_type, norm_type=norm_type,
                        act_type=act_type, mode=mode)
        self.RDB1 = DWT_octave_ResidualDenseBlockTiny_4C(**rdb_args)
        self.RDB2 = DWT_octave_ResidualDenseBlockTiny_4C(**rdb_args)

    def forward(self, x):
        y = self.RDB2(self.RDB1(x))
        scaled_h = y[0].mul(0.2)
        scaled_l = y[1].mul(0.2)
        return (x[0] + scaled_h, x[1] + scaled_l)
##################this is ESRGAN based on DWT_octave
class DWT_Octave_RRDBNet(nn.Module):
    """ESRGAN-style RRDB generator whose trunk is built from the DWT octave blocks.

    ``self.model`` is, in order:
      * a plain 3x3 conv (``in_nc -> nf``);
      * a ``B.ShortcutBlock`` wrapping ``DWT_FirstOctaveConv``, then ``nb``
        ``DWT_octave_RRDBTiny`` blocks, then ``DWT_LastOctaveConv``;
      * the upsampling block(s);
      * two HR convs (``nf -> nf -> out_nc``).
    The layer helpers come from the project module ``B``, which is not
    shown here.

    Args:
        in_nc, out_nc: image channels in / out.
        nf: feature channels of the trunk.
        nb: number of ``DWT_octave_RRDBTiny`` blocks.
        gc: growth channels of the dense blocks. Now actually forwarded;
            it used to be silently hard-coded to 32 (the default).
        alpha: low-frequency channel fraction for the octave layers.
        upscale: overall scale; 3 builds one upsampler with factor 3,
            otherwise ``int(log2(upscale))`` upsamplers with the helper's
            default factor.
        norm_type: passed to the RRDB blocks.
        act_type: passed to the RRDB blocks, the upsamplers and HR_conv0.
        mode: accepted for interface compatibility; the RRDB blocks are
            always built in 'CNA' mode.
        upsample_mode: 'upconv' or 'pixelshuffle'.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, alpha=0.125, upscale=4, norm_type=None,
                 act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(DWT_Octave_RRDBNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv1 = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        fea_conv = B.DWT_FirstOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1,
                                         bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA')
        rb_blocks = [B.DWT_octave_RRDBTiny(nf, kernel_size=3, gc=gc, alpha=alpha, stride=1, bias=True,
                                           pad_type='zero', norm_type=norm_type, act_type=act_type, mode='CNA')
                     for _ in range(nb)]
        LR_conv = B.DWT_LastOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1,
                                       bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA')

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            # NOTE(review): this is a single block that is later star-unpacked,
            # which relies on it being an iterable nn.Sequential.
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv1, B.ShortcutBlock(B.sequential(fea_conv, *rb_blocks, LR_conv)),
                                  *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        return self.model(x)
##############################################################################################

 

实验

 

 

 

改进

通过测试有新发现

https://blog.csdn.net/qq_40587575/article/details/83154042

从上面的测试可以看出,只需要设置J=1(即只做一层小波分解)即可。

改进代码如下:

##################################################################################
##################################################################################
##################################################################################
#DWT octave
# Block for OctConv
####################
class DWT_OctaveConv(nn.Module):
    """Octave layer operating on four single-level DWT sub-bands (variant 2).

    ``forward`` takes and returns a 4-tuple ``(X_ll, X_lh, X_hl, X_hh)``.
    Each band has its own stride-1 convolution mapping ``in_nc -> out_nc``:
    ``l2l`` for LL, ``l2h`` for LH, ``h2l`` for HL and ``h2h`` for HH.
    ``alpha`` no longer splits channels here; it is only accepted for
    interface compatibility. An optional normalization (``n_h``, shared by all
    bands) and activation are then applied once per band.

    Args:
        in_nc, out_nc: channels of every input / output sub-band.
        kernel_size, dilation, groups, bias: passed to every ``nn.Conv2d``.
        stride: stored but not used by ``forward``.
        pad_type: ``'zero'`` pads with ``get_valid_padding``; otherwise 0.
        norm_type, act_type: optional normalization / activation.
        mode: only validated; must be ``'CNA'``, ``'NAC'`` or ``'CNAC'``.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.stride = stride

        self.l2l = nn.Conv2d(in_nc, out_nc, kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(in_nc, out_nc, kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_nc, out_nc, kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc, out_nc, kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        # Every band has out_nc channels, so the norms must have out_nc features.
        # The former alpha-based sizes did not match the conv outputs.
        # n_l is kept for parameter layout but forward only uses n_h.
        self.n_h = norm(norm_type, out_nc) if norm_type else None
        self.n_l = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        X_ll, X_lh, X_hl, X_hh = x

        X_hh = self.h2h(X_hh)
        X_lh = self.l2h(X_lh)
        X_ll = self.l2l(X_ll)
        X_hl = self.h2l(X_hl)

        bands = (X_ll, X_lh, X_hl, X_hh)
        if self.n_h is not None and self.n_l is not None:
            # Exactly once per band. Previously X_hh was normalized twice.
            bands = tuple(self.n_h(band) for band in bands)
        if self.a is not None:
            # Exactly once per band. Previously X_hh was activated twice.
            bands = tuple(self.a(band) for band in bands)

        return bands




class DWT_FirstOctaveConv(nn.Module):
    """First layer of variant 2: decompose the input into four DWT sub-bands.

    ``forward`` applies a single-level ``DWTForward`` (``wave='db3'``,
    ``mode='zero'``) and returns ``(X_ll, X_lh, X_hl, X_hh)``. The three
    detail bands are the slices ``[:, :, 0]``, ``[:, :, 1]`` and
    ``[:, :, 2]`` of the level-1 detail tensor. ``forward`` applies no
    convolution, normalization or activation.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        # No hard-coded .cuda(): the transform is a submodule and follows the
        # parent's .to()/.cuda(), so the layer also works on CPU.
        self.xfm = DWTForward(J=1, wave='db3', mode='zero')
        self.stride = stride
        # NOTE: h2l / h2h / a / n_h / n_l are not used by forward. They are kept
        # so the module's parameters and state_dict keys stay unchanged.
        self.h2l = nn.Conv2d(in_nc, int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc, out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None

    def forward(self, x):
        X_ll, highs = self.xfm(x)
        level1 = highs[0]  # detail bands of the single level, stacked on dim 2
        return X_ll, level1[:, :, 0], level1[:, :, 1], level1[:, :, 2]


class DWT_LastOctaveConv(nn.Module):
    """Last layer of variant 2: rebuild a feature map from four DWT sub-bands.

    ``forward`` takes ``(X_ll, X_lh, X_hl, X_hh)``. It stacks the three detail
    bands on a new dim 2 into a ``(N, C, 3, H, W)`` tensor, wraps it in a
    one-element list and returns ``DWTInverse`` (``db3``, zero mode) of
    ``(X_ll, [details])``.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        # No hard-coded .cuda(): the inverse transform follows the parent's
        # .to()/.cuda() like any other submodule.
        self.ifm = DWTInverse(wave='db3', mode='zero')
        self.stride = stride

        # NOTE: l2h / h2h / a / n_h are not used by forward. They are kept so the
        # module's parameters and state_dict keys stay unchanged.
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)

        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        X_ll, X_lh, X_hl, X_hh = x
        # Same (N, C, 3, H, W) layout the old randn-then-fill code produced.
        # torch.stack keeps the inputs' device and dtype, avoiding a CPU buffer
        # plus a .cuda() copy on every call.
        highs = [torch.stack((X_lh, X_hl, X_hh), dim=2)]
        return self.ifm((X_ll, highs))

class DWT_octave_ResidualDenseBlockTiny_4C(nn.Module):
    '''
    Residual Dense Block built from DWT_OctaveConv layers, operating on
    4-tuples of DWT sub-bands (X_ll, X_lh, X_hl, X_hh).
    style: 4 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)

    Each conv receives the per-band channel concatenation of the block input
    and all previous conv outputs. The output of the last conv is scaled by
    0.2 and added back to the input, band by band.
    '''

    def __init__(self, nc, kernel_size=3, gc=16, alpha=0.5, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_ResidualDenseBlockTiny_4C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        common = dict(bias=bias, pad_type=pad_type, norm_type=norm_type, mode=mode)
        self.conv1 = DWT_OctaveConv(nc, gc, kernel_size, alpha, stride, act_type=act_type, **common)
        self.conv2 = DWT_OctaveConv(nc + gc, gc, kernel_size, alpha, stride, act_type=act_type, **common)
        self.conv3 = DWT_OctaveConv(nc + 2 * gc, gc, kernel_size, alpha, stride, act_type=act_type, **common)
        # In 'CNA' mode the last conv has no activation, so the residual branch
        # is not activated before being added back. This matches the conv_block
        # reference (act_type=last_act). Previously last_act was computed but
        # act_type was passed instead.
        last_act = None if mode == 'CNA' else act_type
        self.conv4 = DWT_OctaveConv(nc + 3 * gc, nc, kernel_size, alpha, stride, act_type=last_act, **common)

    @staticmethod
    def _cat(*feats):
        """Concatenate several band tuples along channels, band by band."""
        return tuple(torch.cat(parts, dim=1) for parts in zip(*feats))

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(self._cat(x, x1))
        x3 = self.conv3(self._cat(x, x1, x2))
        x4 = self.conv4(self._cat(x, x1, x2, x3))
        # Residual scaling by 0.2 as in ESRGAN, applied to each sub-band.
        return tuple(inp + out.mul(0.2) for inp, out in zip(x, x4))


class DWT_octave_RRDBTiny(nn.Module):
    '''
    Residual in Residual Dense Block on 4-tuples of DWT sub-bands
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks).

    Two DWT_octave_ResidualDenseBlockTiny_4C blocks run in sequence. Their
    output is scaled by 0.2 and added to the block input, band by band.
    '''

    def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_RRDBTiny, self).__init__()
        rdb_args = dict(nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc, stride=stride,
                        bias=bias, pad_type=pad_type, norm_type=norm_type,
                        act_type=act_type, mode=mode)
        self.RDB1 = DWT_octave_ResidualDenseBlockTiny_4C(**rdb_args)
        self.RDB2 = DWT_octave_ResidualDenseBlockTiny_4C(**rdb_args)

    def forward(self, x):
        y = self.RDB2(self.RDB1(x))
        scaled = [y[k].mul(0.2) for k in range(4)]
        return tuple(x[k] + scaled[k] for k in range(4))

实验结果:

 

改进2

##################################################################################
##################################################################################
##################################################################################
#DWT octave
# Block for OctConv
####################
class DWT_OctaveConv(nn.Module):
    """Octave convolution whose low-frequency output is assembled in the DWT domain (variant 3).

    ``forward`` takes a ``(X_h, X_l)`` pair and returns a new pair:
      * ``X_h_out = h2h(X_h) + l2h(X_l)`` (stride 1, same spatial size);
      * each input gets a single-level ``DWTForward`` (``db3``, zero mode).
        ``h2l`` is applied to each of the four sub-bands of ``X_h`` and
        ``l2l`` to each sub-band of ``X_l``, with weights shared across
        bands. The two results are summed band by band and inverted with
        ``DWTInverse`` to give ``X_l_out``.
    Optional normalization / activation are applied to both outputs.

    Args:
        in_nc, out_nc: total input / output channels. ``int(alpha * nc)`` of
            them belong to the low-frequency branch, the rest to the high one.
        kernel_size, dilation, groups, bias: passed to every ``nn.Conv2d``.
        alpha: fraction of channels assigned to the low-frequency branch.
        stride: stored but not used by ``forward``.
        pad_type: ``'zero'`` pads with ``get_valid_padding``; otherwise 0.
        norm_type, act_type: optional normalization / activation.
        mode: only validated; must be ``'CNA'``, ``'NAC'`` or ``'CNAC'``.
    """

    @staticmethod
    def _split_channels(nc, alpha):
        """Return ``(low, high)`` channel counts for ``nc`` channels."""
        low = int(alpha * nc)
        return low, nc - low

    @staticmethod
    def _subbands(coeffs):
        """Unpack single-level ``DWTForward`` output into ``(LL, LH, HL, HH)``."""
        low, highs = coeffs
        level1 = highs[0]
        return low, level1[:, :, 0], level1[:, :, 1], level1[:, :, 2]

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        # No hard-coded .cuda(): both transforms follow the parent's
        # .to()/.cuda() together with the convolutions.
        self.xfm = DWTForward(J=1, wave='db3', mode='zero')
        self.ifm = DWTInverse(wave='db3', mode='zero')
        self.stride = stride

        in_l, in_h = self._split_channels(in_nc, alpha)
        out_l, out_h = self._split_channels(out_nc, alpha)

        self.l2l = nn.Conv2d(in_l, out_l, kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(in_l, out_h, kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_h, out_l, kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_h, out_h, kernel_size, 1, padding, dilation, groups, bias)

        self.a = act(act_type) if act_type else None
        # The norms are sized from the same split as the convs. The old
        # int(out_nc * (1 - alpha)) could disagree with out_nc - int(alpha * out_nc).
        self.n_h = norm(norm_type, out_h) if norm_type else None
        self.n_l = norm(norm_type, out_l) if norm_type else None

    def forward(self, x):
        X_h, X_l = x

        X_h2h = self.h2h(X_h)
        X_l2h = self.l2h(X_l)

        # Project every DWT sub-band of each input to the low-frequency width.
        h_bands = [self.h2l(band) for band in self._subbands(self.xfm(X_h))]
        l_bands = [self.l2l(band) for band in self._subbands(self.xfm(X_l))]
        ll, lh, hl, hh = [l_band + h_band for l_band, h_band in zip(l_bands, h_bands)]

        X_h = X_h2h + X_l2h
        # torch.stack builds the (N, C, 3, H, W) detail tensor on the inputs'
        # device and dtype. The old code filled a CPU randn buffer and then
        # copied it with .cuda().
        X_l = self.ifm((ll, [torch.stack((lh, hl, hh), dim=2)]))

        if self.n_h is not None and self.n_l is not None:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)

        if self.a is not None:
            X_h = self.a(X_h)
            X_l = self.a(X_l)

        return X_h, X_l




class DWT_FirstOctaveConv(nn.Module):
    """First octave layer (variant 3): split one feature map into an ``(X_h, X_l)`` pair.

    Two stride-1 convolutions read the same input. ``h2h`` produces the
    high-frequency branch and ``h2l`` produces the low-frequency branch, both
    at the input's spatial size. Optional norm / activation are applied
    afterwards. ``self.xfm`` is constructed but not used by ``forward``.

    Args:
        in_nc: input channels (both convs read all of them).
        out_nc: total output channels; ``int(alpha * out_nc)`` go to the
            low-frequency branch, the rest to the high-frequency branch.
        kernel_size, dilation, groups, bias: passed to both ``nn.Conv2d``.
        alpha: fraction of output channels assigned to the low branch.
        stride: stored on the module but not used by ``forward``.
        pad_type: ``'zero'`` pads with ``get_valid_padding``; otherwise 0.
        norm_type, act_type: optional normalization / activation.
        mode: only validated; must be ``'CNA'``, ``'NAC'`` or ``'CNAC'``.
    """

    @staticmethod
    def _split_channels(nc, alpha):
        """Return ``(low, high)`` channel counts for ``nc`` channels."""
        low = int(alpha * nc)
        return low, nc - low

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        # Unused by forward; kept so the submodule layout is unchanged. No
        # hard-coded .cuda(): it follows the parent's .to()/.cuda().
        self.xfm = DWTForward(J=1, wave='db3', mode='zero')
        self.stride = stride
        out_l, out_h = self._split_channels(out_nc, alpha)
        # Low-frequency branch.
        self.h2l = nn.Conv2d(in_nc, out_l, kernel_size, 1, padding, dilation, groups, bias)
        # High-frequency branch.
        self.h2h = nn.Conv2d(in_nc, out_h, kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        # The norms are sized exactly like the conv outputs. The previous
        # int(out_nc * (1 - alpha)) could disagree with out_nc - int(alpha * out_nc).
        self.n_h = norm(norm_type, out_h) if norm_type else None
        self.n_l = norm(norm_type, out_l) if norm_type else None

    def forward(self, x):
        X_h = self.h2h(x)
        X_l = self.h2l(x)

        if self.n_h is not None and self.n_l is not None:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)

        if self.a is not None:
            # The same activation module is shared by both branches.
            X_h = self.a(X_h)
            X_l = self.a(X_l)

        return X_h, X_l


class DWT_LastOctaveConv(nn.Module):
    """Final octave layer (variant 3): merge an ``(X_h, X_l)`` pair into one map.

    The high and low branches are each convolved to ``out_nc`` channels at
    stride 1 and summed; normalization and activation are optional.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        if pad_type == 'zero':
            pad = get_valid_padding(kernel_size, dilation)
        else:
            pad = 0
        self.stride = stride

        n_low = int(alpha * in_nc)
        conv_args = (kernel_size, 1, pad, dilation, groups, bias)
        self.l2h = nn.Conv2d(n_low, out_nc, *conv_args)
        self.h2h = nn.Conv2d(in_nc - n_low, out_nc, *conv_args)

        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        hi, lo = x
        merged = self.h2h(hi) + self.l2h(lo)
        for layer in (self.n_h, self.a):
            if layer:
                merged = layer(merged)
        return merged

class DWT_octave_ResidualDenseBlockTiny_4C(nn.Module):
    '''
    Residual Dense Block built from DWT octave convolutions
    style: 4 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)

    Features travel as a tuple of branch tensors (``(X_h, X_l)`` here). Each
    dense connection concatenates a branch with the matching branch of every
    earlier output along dim 1. The block returns ``x + 0.2 * conv4(...)``
    branch by branch.
    '''

    def __init__(self, nc, kernel_size=3, gc=16,alpha=0.5, stride=1, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_ResidualDenseBlockTiny_4C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = DWT_OctaveConv(nc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
             norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = DWT_OctaveConv(nc+gc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
             norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = DWT_OctaveConv(nc+2*gc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
             norm_type=norm_type, act_type=act_type, mode=mode)
        # As in the plain RDB this block mirrors, the last conv has no
        # activation in 'CNA' mode. Previously last_act was computed but
        # act_type was passed instead.
        # NOTE(review): with a parametric activation (e.g. 'prelu'), checkpoints
        # saved before this fix presumably carry extra conv4 activation weights.
        last_act = None if mode == 'CNA' else act_type
        self.conv4 = DWT_OctaveConv(nc+3*gc, nc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
             norm_type=norm_type, act_type=last_act, mode=mode)

    @staticmethod
    def _cat_branches(*feats):
        """Concatenate matching branches of several feature tuples along dim 1."""
        return tuple(torch.cat(parts, dim=1) for parts in zip(*feats))

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(self._cat_branches(x, x1))
        x3 = self.conv3(self._cat_branches(x, x1, x2))
        x4 = self.conv4(self._cat_branches(x, x1, x2, x3))

        # Residual scaling (0.2) applied to every branch.
        return tuple(base + res.mul(0.2) for base, res in zip(x, x4))


class DWT_octave_RRDBTiny(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)

    Runs two octave residual dense blocks back to back, scales their output
    by 0.2 and adds it onto the incoming (high, low) pair.
    '''

    def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_RRDBTiny, self).__init__()
        rdb_kwargs = dict(nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc, stride=stride,
                          bias=bias, pad_type=pad_type, norm_type=norm_type,
                          act_type=act_type, mode=mode)
        self.RDB1 = DWT_octave_ResidualDenseBlockTiny_4C(**rdb_kwargs)
        self.RDB2 = DWT_octave_ResidualDenseBlockTiny_4C(**rdb_kwargs)

    def forward(self, x):
        y = self.RDB2(self.RDB1(x))
        high = x[0] + y[0].mul(0.2)
        low = x[1] + y[1].mul(0.2)
        return high, low

 

结果

 

改进3

##################################################################################
##################################################################################
##################################################################################
#DWT octave
# Block for OctConv
####################
class DWT_OctaveConv(nn.Module):
    """Octave convolution whose low-frequency branch lives in the wavelet domain.

    Input and output are both 5-tuples
    ``(X_h, X_l_ll, X_l_lh, X_l_hl, X_l_hh)``: a high-frequency feature map
    followed by the four sub-bands of a single-level ``db3`` DWT
    (``mode='zero'``) that carry the low-frequency features.

    Paths:
        * high -> high: ``h2h(X_h)``.
        * low -> high: inverse DWT of the four sub-bands, then ``l2h``.
        * high -> low: forward DWT of ``X_h``; ``h2l`` is applied to every
          resulting sub-band.
        * low -> low: ``l2l`` is applied to every incoming sub-band.

    High contributions are summed. Low contributions are summed sub-band by
    sub-band. Optional normalisation and activation follow.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        # The wavelet transforms are registered as submodules, so they move
        # with model.to(device) / model.cuda(); no hard-coded .cuda() here.
        self.xfm = DWTForward(J=1, wave='db3', mode='zero')
        self.ifm = DWTInverse(wave='db3', mode='zero')
        # Kept for interface compatibility; every conv below uses stride 1.
        self.stride = stride

        # Channel split: int(alpha * n) channels belong to the low branch.
        low_in = int(alpha * in_nc)
        low_out = int(alpha * out_nc)
        high_in = in_nc - low_in
        high_out = out_nc - low_out

        self.l2l = nn.Conv2d(low_in, low_out,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(low_in, high_out,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(high_in, low_out,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(high_in, high_out,
                                kernel_size, 1, padding, dilation, groups, bias)

        self.a = act(act_type) if act_type else None
        # Norm widths reuse the conv output widths. int(out_nc * (1 - alpha))
        # disagrees with h2h's width for odd out_nc (15 * 0.5 -> 7, not 8).
        self.n_h = norm(norm_type, high_out) if norm_type else None
        self.n_l = norm(norm_type, low_out) if norm_type else None

    def forward(self, x):
        X_h, X_l_ll, X_l_lh, X_l_hl, X_l_hh = x

        # high -> high
        X_h2h = self.h2h(X_h)

        # low -> high: rebuild a spatial low-frequency map with the inverse
        # DWT. torch.stack keeps the sub-bands' device, dtype and autograd
        # graph. A randn(...) buffer followed by .cuda() would force float32
        # on the current GPU.
        X_l = self.ifm((X_l_ll, [torch.stack((X_l_lh, X_l_hl, X_l_hh), dim=2)]))
        X_l2h = self.l2h(X_l)

        # high -> low: DWT of X_h, then h2l on each sub-band (ll, lh, hl, hh).
        yl, yh = self.xfm(X_h)
        h2l_bands = [self.h2l(band) for band in (yl, *yh[0].unbind(dim=2))]

        # low -> low: l2l on each incoming sub-band, in the same order.
        l2l_bands = [self.l2l(band) for band in (X_l_ll, X_l_lh, X_l_hl, X_l_hh)]

        X_h = X_h2h + X_l2h
        X_l_bands = [low + high for low, high in zip(l2l_bands, h2l_bands)]

        # Optional layers are modules or None; check identity, not truthiness.
        if self.n_h is not None:
            X_h = self.n_h(X_h)
        if self.n_l is not None:
            X_l_bands = [self.n_l(band) for band in X_l_bands]

        if self.a is not None:
            X_h = self.a(X_h)
            X_l_bands = [self.a(band) for band in X_l_bands]

        return (X_h, *X_l_bands)



class DWT_FirstOctaveConv(nn.Module):
    """Entry octave layer: one feature map -> ``(X_h, X_l_ll, X_l_lh, X_l_hl, X_l_hh)``.

    ``h2h`` produces the high-frequency map. ``h2l`` produces a
    low-frequency map, which is split into four sub-bands by a single-level
    ``db3`` DWT (``mode='zero'``). Optional normalisation and activation are
    then applied to the high map and to each sub-band.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        # Registered as a submodule, so it follows model.to(device) /
        # model.cuda(); no hard-coded .cuda() here.
        self.xfm = DWTForward(J=1, wave='db3', mode='zero')
        # Kept for interface compatibility; both convs use stride 1.
        self.stride = stride

        low_out = int(alpha * out_nc)
        high_out = out_nc - low_out
        ###low frequency
        self.h2l = nn.Conv2d(in_nc, low_out,
                                kernel_size, 1, padding, dilation, groups, bias)
        ###high frequency
        self.h2h = nn.Conv2d(in_nc, high_out,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        # Norm widths reuse the conv output widths. int(out_nc * (1 - alpha))
        # disagrees with h2h's width for odd out_nc.
        self.n_h = norm(norm_type, high_out) if norm_type else None
        self.n_l = norm(norm_type, low_out) if norm_type else None

    def forward(self, x):
        X_h = self.h2h(x)
        X_l = self.h2l(x)

        # Split the low-frequency map into its (ll, lh, hl, hh) sub-bands.
        yl, yh = self.xfm(X_l)
        X_l_bands = [yl, *yh[0].unbind(dim=2)]

        # Optional layers are modules or None; check identity, not truthiness.
        if self.n_h is not None:
            X_h = self.n_h(X_h)
        if self.n_l is not None:
            X_l_bands = [self.n_l(band) for band in X_l_bands]

        if self.a is not None:
            X_h = self.a(X_h)
            X_l_bands = [self.a(band) for band in X_l_bands]

        return (X_h, *X_l_bands)


class DWT_LastOctaveConv(nn.Module):
    """Exit octave layer: ``(X_h, X_l_ll, X_l_lh, X_l_hl, X_l_hh)`` -> one feature map.

    The four low-frequency sub-bands are merged back into a spatial map by a
    single-level inverse ``db3`` DWT (``mode='zero'``), passed through ``l2h``
    and added to ``h2h(X_h)``. Optional normalisation and activation follow.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        # Registered as a submodule, so it follows model.to(device) /
        # model.cuda(); no hard-coded .cuda() here.
        self.ifm = DWTInverse(wave='db3', mode='zero')
        # Kept for interface compatibility; both convs use stride 1.
        self.stride = stride

        low_in = int(alpha * in_nc)
        self.l2h = nn.Conv2d(low_in, out_nc,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - low_in, out_nc,
                                kernel_size, 1, padding, dilation, groups, bias)

        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        X_h, X_l_ll, X_l_lh, X_l_hl, X_l_hh = x

        # Inverse DWT expects (yl, [yh]) with the three detail bands stacked
        # on dim 2. torch.stack keeps device, dtype and autograd graph;
        # a randn(...) buffer plus .cuda() would force float32 on the current GPU.
        X_l = self.ifm((X_l_ll, [torch.stack((X_l_lh, X_l_hl, X_l_hh), dim=2)]))

        X_h = self.h2h(X_h) + self.l2h(X_l)

        # Optional layers are modules or None; check identity, not truthiness.
        if self.n_h is not None:
            X_h = self.n_h(X_h)

        if self.a is not None:
            X_h = self.a(X_h)

        return X_h

class DWT_octave_ResidualDenseBlockTiny_4C(nn.Module):
    '''
    Residual Dense Block built from DWT octave convolutions
    style: 4 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)

    Features travel as 5-tuples ``(X_h, X_l_ll, X_l_lh, X_l_hl, X_l_hh)``.
    Each dense connection concatenates a branch with the matching branch of
    every earlier output along dim 1. The block returns ``x + 0.2 * conv4(...)``
    branch by branch.
    '''

    def __init__(self, nc, kernel_size=3, gc=16,alpha=0.5, stride=1, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_ResidualDenseBlockTiny_4C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = DWT_OctaveConv(nc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
             norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = DWT_OctaveConv(nc+gc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
             norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = DWT_OctaveConv(nc+2*gc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
             norm_type=norm_type, act_type=act_type, mode=mode)
        # As in the plain RDB this block mirrors, the last conv has no
        # activation in 'CNA' mode. Previously last_act was computed but
        # act_type was passed instead.
        # NOTE(review): with a parametric activation (e.g. 'prelu'), checkpoints
        # saved before this fix will contain conv4's activation weights.
        last_act = None if mode == 'CNA' else act_type
        self.conv4 = DWT_OctaveConv(nc+3*gc, nc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
             norm_type=norm_type, act_type=last_act, mode=mode)

    @staticmethod
    def _cat_branches(*feats):
        """Concatenate matching branches of several feature tuples along dim 1."""
        return tuple(torch.cat(parts, dim=1) for parts in zip(*feats))

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(self._cat_branches(x, x1))
        x3 = self.conv3(self._cat_branches(x, x1, x2))
        x4 = self.conv4(self._cat_branches(x, x1, x2, x3))

        # Residual scaling (0.2) applied to every branch.
        return tuple(base + res.mul(0.2) for base, res in zip(x, x4))



class DWT_octave_RRDBTiny(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)

    Two octave residual dense blocks in sequence. Their output is scaled by
    0.2 and added onto each of the five input branches
    (X_h plus the four low-frequency sub-bands).
    '''

    def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_RRDBTiny, self).__init__()
        rdb_kwargs = dict(nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc, stride=stride,
                          bias=bias, pad_type=pad_type, norm_type=norm_type,
                          act_type=act_type, mode=mode)
        self.RDB1 = DWT_octave_ResidualDenseBlockTiny_4C(**rdb_kwargs)
        self.RDB2 = DWT_octave_ResidualDenseBlockTiny_4C(**rdb_kwargs)

    def forward(self, x):
        y = self.RDB2(self.RDB1(x))
        # Five branches: X_h followed by the ll, lh, hl, hh sub-bands.
        return tuple(x[k] + y[k].mul(0.2) for k in range(5))

结果

 

 

 

 

  • 5
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
要在 PyTorch 中使用离散小波变换(Discrete Wavelet Transform, DWT),可以按照以下步骤操作:1. 如果本地还没有 pytorch_wavelets 库,先将其克隆到本地:git clone https://github.com/fbcotter/pytorch_wavelets 2. 进入仓库目录并安装该库:cd pytorch_wavelets,然后执行 pip install . 3. 安装测试所需的依赖库:pip install -r tests/requirements.txt 完成以上步骤后,就可以使用 pytorch_wavelets 库进行离散小波变换了。[1][2][3] 引用:[1][2] 实验笔记之——基于DWT的octave layer(DWT在pytorch中实现),https://blog.csdn.net/gwplovekimi/article/details/90169433 ;[3] PyTorch学习笔记——PyTorch模块和基础实战,https://blog.csdn.net/qq_56551150/article/details/125452424

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值