实验笔记之——基于DWT的octave layer(DWT在pytorch中实现)

之前的博文《论文阅读笔记之——《Multi-level Wavelet-CNN for Image Restoration》及基于pytorch的复现》曾经研究过MWCNN(Multi-level Wavelet-CNN)。本博文采用DWT变换代替octave中的pooling。


类似于博文《实验笔记之——octave conv (without pooling)》对octave layer的结构进行改进如下:



git clone https://github.com/fbcotter/pytorch_wavelets

cd pytorch_wavelets
pip install .
pip install -r tests/requirements.txt



#DWT octave
# Block for OctConv
class DWT_OctaveConv(nn.Module):
    """Octave-style convolution whose low branch goes through a DTCWT round trip.

    The layer consumes and produces a ``(X_h, X_l)`` tuple. Four convolutions
    (high->high, high->low, low->high, low->low) are applied and summed per
    output branch, followed by optional normalisation and activation.

    Args:
        in_nc: total input channels; ``int(alpha * in_nc)`` of them belong to
            the low branch, the remainder to the high branch.
        out_nc: total output channels, split the same way.
        kernel_size: kernel size of all four convolutions.
        alpha: fraction of channels assigned to the low branch.
        stride: stored on the module; the convolutions always use stride 1.
        dilation, groups, bias: forwarded to ``nn.Conv2d``.
        pad_type: ``'zero'`` pads via ``get_valid_padding``; anything else
            uses no padding.
        norm_type: if truthy, passed to ``norm`` to build per-branch norms.
        act_type: if truthy, passed to ``act`` to build the activation.
        mode: validated against ``'CNA'``, ``'NAC'``, ``'CNAC'`` only.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        # The transforms are registered as sub-modules without a hard-coded
        # `.cuda()`, so the layer can be built on CPU-only machines; they follow
        # the parent when it is moved with `.to(device)` / `.cuda()`.
        self.xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b')
        self.ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b')
        self.stride = stride

        # Channel split shared by the convolutions AND the norm layers.
        low_in = int(alpha * in_nc)
        high_in = in_nc - low_in
        low_out = int(alpha * out_nc)
        high_out = out_nc - low_out

        self.l2l = nn.Conv2d(low_in, low_out,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(low_in, high_out,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(high_in, low_out,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(high_in, high_out,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        # The high-branch norm must match the conv width `out_nc - int(alpha*out_nc)`;
        # `int(out_nc*(1 - alpha))` can be one channel short (e.g. 64, 0.1 -> 57 vs 58).
        self.n_h = norm(norm_type, high_out) if norm_type else None
        self.n_l = norm(norm_type, low_out) if norm_type else None

    def forward(self, x):
        """Apply the layer to ``x = (X_h, X_l)`` and return the new ``(X_h, X_l)``."""
        X_h, X_l = x

        X_h2h = self.h2h(X_h)
        X_h2l = self.h2l(X_h)

        # NOTE(review): the previous code called `self.ifm((X_l, X_ll))` with
        # `X_ll` never assigned (NameError) while `self.xfm` was unused. The
        # decomposition below is the reading consistent with the layer's channel
        # sizes -- confirm against the intended design.
        # Presumably H and W should be compatible with a 3-level transform so the
        # reconstruction has the same size as X_h -- TODO confirm.
        low_pass, high_bands = self.xfm(X_l)
        X_l_rec = self.ifm((low_pass, high_bands))

        X_l2h = self.l2h(X_l_rec)
        X_l2l = self.l2l(X_l_rec)

        X_h = X_l2h + X_h2h
        X_l = X_h2l + X_l2l

        if self.n_h is not None and self.n_l is not None:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)

        if self.a is not None:
            X_h = self.a(X_h)
            X_l = self.a(X_l)

        return X_h, X_l

class DWT_FirstOctaveConv(nn.Module):
    """Entry layer of the octave stack: splits one tensor into ``(X_h, X_l)``.

    Two convolutions are applied to the same input; one produces the high
    branch and the other the low branch. No pooling or wavelet transform is
    applied here.

    Args:
        in_nc: input channels of the single input tensor.
        out_nc: total output channels; ``int(alpha * out_nc)`` go to the low
            branch and the remainder to the high branch.
        kernel_size: kernel size of both convolutions.
        alpha: fraction of output channels assigned to the low branch.
        stride: stored on the module; the convolutions always use stride 1.
        dilation, groups, bias: forwarded to ``nn.Conv2d``.
        pad_type: ``'zero'`` pads via ``get_valid_padding``; anything else
            uses no padding.
        norm_type: if truthy, passed to ``norm`` to build per-branch norms.
        act_type: if truthy, passed to ``act`` to build the activation.
        mode: validated against ``'CNA'``, ``'NAC'``, ``'CNAC'`` only.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.stride = stride

        # Channel split shared by the convolutions AND the norm layers.
        low_out = int(alpha * out_nc)
        high_out = out_nc - low_out

        # Low-frequency branch.
        self.h2l = nn.Conv2d(in_nc, low_out,
                                kernel_size, 1, padding, dilation, groups, bias)
        # High-frequency branch.
        self.h2h = nn.Conv2d(in_nc, high_out,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        # The high-branch norm must match the conv width `out_nc - int(alpha*out_nc)`;
        # `int(out_nc*(1 - alpha))` can be one channel short (e.g. 64, 0.1 -> 57 vs 58).
        self.n_h = norm(norm_type, high_out) if norm_type else None
        self.n_l = norm(norm_type, low_out) if norm_type else None

    def forward(self, x):
        """Return ``(X_h, X_l)`` computed from the single input tensor ``x``."""
        X_h = self.h2h(x)
        X_l = self.h2l(x)

        # Normalisation is applied only when both branch norms exist.
        if self.n_h is not None and self.n_l is not None:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)

        # The same activation module is shared by both branches.
        if self.a is not None:
            X_h = self.a(X_h)
            X_l = self.a(X_l)

        return X_h, X_l

class DWT_LastOctaveConv(nn.Module):
    """Exit layer of the octave stack: merges ``(X_h, X_l)`` into one tensor.

    Each branch is convolved to ``out_nc`` channels and the two results are
    summed, followed by optional normalisation and activation.

    Args:
        in_nc: total input channels; ``int(alpha * in_nc)`` belong to the low
            branch and the remainder to the high branch.
        out_nc: channels of the single output tensor.
        kernel_size: kernel size of both convolutions.
        alpha: fraction of input channels assigned to the low branch.
        stride: stored on the module; the convolutions always use stride 1.
        dilation, groups, bias: forwarded to ``nn.Conv2d``.
        pad_type: ``'zero'`` pads via ``get_valid_padding``; anything else
            uses no padding.
        norm_type: if truthy, passed to ``norm`` to build the output norm.
        act_type: if truthy, passed to ``act`` to build the activation.
        mode: validated against ``'CNA'``, ``'NAC'``, ``'CNAC'`` only.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.stride = stride

        low_in = int(alpha * in_nc)
        high_in = in_nc - low_in

        self.l2h = nn.Conv2d(low_in, out_nc,
                                kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(high_in, out_nc,
                                kernel_size, 1, padding, dilation, groups, bias)

        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        """Return the merged tensor for ``x = (X_h, X_l)``."""
        X_h, X_l = x

        X_h2h = self.h2h(X_h)
        # The low-branch projection had been commented out, leaving `X_l2h`
        # undefined (NameError) and `self.l2h` unused. Both branches are added
        # directly, so they must share the same spatial size.
        X_l2h = self.l2h(X_l)
        X_h = X_h2h + X_l2h

        if self.n_h is not None:
            X_h = self.n_h(X_h)

        if self.a is not None:
            X_h = self.a(X_h)

        return X_h

class DWT_octave_ResidualDenseBlockTiny_4C(nn.Module):
    """
    Residual Dense Block
    style: 4 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)

    Every convolution is a ``DWT_OctaveConv`` and every tensor is a
    ``(high, low)`` tuple. Each conv receives the channel-wise concatenation
    of the block input and all previous conv outputs, per branch. The last
    conv maps back to ``nc`` channels and its output, scaled by 0.2, is added
    to the block input.

    Args:
        nc: total channels of the block input/output.
        kernel_size: kernel size of every conv.
        gc: growth channel, i.e. intermediate channels.
        alpha: low-branch channel fraction passed to every conv.
        stride, bias, pad_type, norm_type, act_type, mode: passed unchanged
            to every ``DWT_OctaveConv``.
    """

    def __init__(self, nc, kernel_size=3, gc=16,alpha=0.5, stride=1, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_ResidualDenseBlockTiny_4C, self).__init__()
        # Options shared by all four convolutions.
        shared = dict(bias=bias, pad_type=pad_type, norm_type=norm_type,
                      act_type=act_type, mode=mode)
        self.conv1 = DWT_OctaveConv(nc, gc, kernel_size, alpha, stride, **shared)
        self.conv2 = DWT_OctaveConv(nc + gc, gc, kernel_size, alpha, stride, **shared)
        self.conv3 = DWT_OctaveConv(nc + 2 * gc, gc, kernel_size, alpha, stride, **shared)
        # NOTE: the final conv keeps `act_type`. A leftover `last_act` local
        # (missing its `else:` and never passed anywhere) was removed; the
        # behaviour of conv4 is unchanged.
        self.conv4 = DWT_OctaveConv(nc + 3 * gc, nc, kernel_size, alpha, stride, **shared)

    @staticmethod
    def _cat(*pairs):
        """Concatenate (high, low) tuples along the channel axis, per branch."""
        highs = torch.cat([p[0] for p in pairs], dim=1)
        lows = torch.cat([p[1] for p in pairs], dim=1)
        return highs, lows

    def forward(self, x):
        """Return ``x + 0.2 * conv4(...)`` applied per branch of ``x = (high, low)``."""
        x1 = self.conv1(x)
        x2 = self.conv2(self._cat(x, x1))
        x3 = self.conv3(self._cat(x, x1, x2))
        x4 = self.conv4(self._cat(x, x1, x2, x3))

        # Residual scaling by 0.2 on both branches.
        return (x[0] + x4[0].mul(0.2), x[1] + x4[1].mul(0.2))

class DWT_octave_RRDBTiny(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)

    Chains two ``DWT_octave_ResidualDenseBlockTiny_4C`` blocks and adds their
    output, scaled by 0.2, to the input. Input and output are ``(high, low)``
    tuples.

    Args:
        nc: total channels of the block input/output.
        kernel_size, gc, stride, alpha, bias, pad_type, norm_type, act_type,
            mode: passed unchanged to both dense blocks.
    """

    def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_RRDBTiny, self).__init__()
        # Both dense blocks are configured identically.
        rdb_kwargs = dict(nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc, stride=stride,
                          bias=bias, pad_type=pad_type, norm_type=norm_type,
                          act_type=act_type, mode=mode)
        self.RDB1 = DWT_octave_ResidualDenseBlockTiny_4C(**rdb_kwargs)
        self.RDB2 = DWT_octave_ResidualDenseBlockTiny_4C(**rdb_kwargs)

    def forward(self, x):
        """Return ``x + 0.2 * RDB2(RDB1(x))`` applied per branch of ``x = (high, low)``."""
        out = self.RDB2(self.RDB1(x))

        # Residual scaling by 0.2 on both branches.
        return (x[0] + out[0].mul(0.2), x[1] + out[1].mul(0.2))
##################this is ESRGAN based on DWT_octave
class DWT_Octave_RRDBNet(nn.Module):
    """ESRGAN-style generator whose trunk is built from DWT octave blocks.

    Layout: conv -> ShortcutBlock(first octave conv -> ``nb`` octave RRDB
    blocks -> last octave conv) -> upsampler(s) -> two HR convs.

    Args:
        in_nc: input image channels.
        out_nc: output image channels.
        nf: feature channels used throughout the trunk.
        nb: number of ``DWT_octave_RRDBTiny`` blocks.
        gc: accepted for interface compatibility.
            NOTE(review): the RRDB blocks are built with a fixed ``gc=32``
            and ignore this argument -- confirm whether that is intended.
        alpha: low-branch channel fraction passed to the octave layers.
        upscale: scale factor; 3 uses one x3 upsample block, otherwise
            ``int(log2(upscale))`` default upsample blocks are stacked.
        norm_type, act_type: passed to the RRDB blocks; ``act_type`` is also
            used by the upsampler and ``HR_conv0``.
        mode: accepted but not used here (blocks are built with ``'CNA'``).
        upsample_mode: ``'upconv'`` or ``'pixelshuffle'``.

    Raises:
        NotImplementedError: if ``upsample_mode`` is not one of the above.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32,alpha=0.125, upscale=4, norm_type=None, \
            act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(DWT_Octave_RRDBNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv1 = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        fea_conv = B.DWT_FirstOctaveConv(nf, nf, kernel_size=3,alpha=alpha, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA')
        rb_blocks = [B.DWT_octave_RRDBTiny(nf, kernel_size=3, gc=32,alpha=alpha,stride=1, bias=True, pad_type='zero', \
            norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
        LR_conv = B.DWT_LastOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA')

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            # Restored `else:` -- without it 'pixelshuffle' always raised and an
            # unknown mode fell through with `upsample_block` undefined.
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            # Restored `else:` -- without it `upsampler` was undefined for every
            # scale other than 3, and the x3 block was overwritten.
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv1,B.ShortcutBlock(B.sequential(fea_conv,*rb_blocks, LR_conv)),\
            *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        """Run the full model on ``x`` and return its output."""
        x = self.model(x)
        return x











要在pytorch中使用离散小波变换(Discrete Wavelet Transform, DWT),可以按照以下步骤进行操作: 1. 如果还没有下载pytorch_wavelets库,先通过以下命令将其克隆到本地: `git clone https://github.com/fbcotter/pytorch_wavelets` 2. 进入该目录并安装pytorch_wavelets库: `cd pytorch_wavelets`,然后执行 `pip install .` 3. 如果需要运行测试,再安装测试所需的依赖库: `pip install -r tests/requirements.txt` 完成以上步骤后,就可以使用pytorch_wavelets库进行小波变换了(上文代码使用的是其中的DTCWTForward和DTCWTInverse)。


