【cvpr2022】AP-BSN网络解读

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

AP-BSN网络是在BSN网络基础上改进了shuffle机制,打破了自然图像中噪声具有空间连续性的特点。该文章主要创新点有三:

  • 提出了非对称的pixshuffle-downsample(AP);
  • 采取了随即替换的精细化方法;
  • 是第一个面向真实世界图像去噪的自监督网络。

一、BSN 和 Pisshuffle Downsample

盲点网络(Blind-spot network,BSN)是卷积神经网络的一种变种,如图所示,该变种将卷积核的中心点置零,即作为一种无视中心像素的网络。

盲点示意图(一维)

该网络中需要满足两条假设才可以进行训练:

  • 假设一:噪声在空间上是像素级的;(这一点在真实图像中并不满足)
  • 假设二:噪声是独立且零均值的 (这一点在深空探测中并不满足)
    由于实际空间中多数图像的噪声都是空间相关的,SIDD和DND数据集中的噪声都说明了该问题,因此假设一并不被满足,为此有学者设计了一种下采样方式,通俗得理解为等间隔取点,或者以不同的相位进行多组池化,如图所示。不同步长的pixshuffle downsampling

这种操作带来了两个影响:打破了空间相关性和造成了伪影,即右下角绿框图像中红色箭头所指的位置本是一个木杆,但是在PD过程中成为一个孤立的像元,BSN的推理过程便会将该位置识别为噪声进行抑制。

二、改进思路

1、非对称下采样

非对称下采样即使名称中AP得由来,也是该文章最重要得创新点。简单讲,即是在训练过程中使用 P D 5 PD_5 PD5打破空间相关性,在推理过程中选用 P D 2 PD_2 PD2尽可能地抑制伪影得产生。
PD得代码为:

def pixel_shuffle_down_sampling(x:torch.Tensor, f:int, pad:int=0, pad_value:float=0.):
    '''
    pixel-shuffle down-sampling (PD) from "When AWGN-denoiser meets real-world noise." (AAAI 2019)
    Args:
        x (Tensor) : input tensor
        f (int) : factor of PD
        pad (int) : number of pad between each down-sampled images
        pad_value (float) : padding value
    Return:
        pd_x (Tensor) : down-shuffled image tensor with pad or not
    '''
    # single image tensor
    if len(x.shape) == 3:
        c,w,h = x.shape
        unshuffled = F.pixel_unshuffle(x, f)
        if pad != 0: unshuffled = F.pad(unshuffled, (pad, pad, pad, pad), value=pad_value)
        return unshuffled.view(c,f,f,w//f+2*pad,h//f+2*pad).permute(0,1,3,2,4).reshape(c, w+2*f*pad, h+2*f*pad)
    # batched image tensor
    else:
        b,c,w,h = x.shape
        unshuffled = F.pixel_unshuffle(x, f)
        if pad != 0: unshuffled = F.pad(unshuffled, (pad, pad, pad, pad), value=pad_value)
        return unshuffled.view(b,c,f,f,w//f+2*pad,h//f+2*pad).permute(0,1,2,4,3,5).reshape(b,c,w+2*f*pad, h+2*f*pad)

def pixel_shuffle_up_sampling(x:torch.Tensor, f:int, pad:int=0):
    '''
    inverse of pixel-shuffle down-sampling (PD)
    see more details about PD in pixel_shuffle_down_sampling()
    Args:
        x (Tensor) : input tensor
        f (int) : factor of PD
        pad (int) : number of pad will be removed
    '''
    # single image tensor
    if len(x.shape) == 3:
        c,w,h = x.shape
        before_shuffle = x.view(c,f,w//f,f,h//f).permute(0,1,3,2,4).reshape(c*f*f,w//f,h//f)
        if pad != 0: before_shuffle = before_shuffle[..., pad:-pad, pad:-pad]
        return F.pixel_shuffle(before_shuffle, f)   
    # batched image tensor
    else:
        b,c,w,h = x.shape
        before_shuffle = x.view(b,c,f,w//f,f,h//f).permute(0,1,2,4,3,5).reshape(b,c*f*f,w//f,h//f)
        if pad != 0: before_shuffle = before_shuffle[..., pad:-pad, pad:-pad]
        return F.pixel_shuffle(before_shuffle, f)

2、随即替换优化(R3)

最经典的图像去噪算法即是均值滤波。作为一种低通滤波器,其最主要的影响便是抑制了图像中的高频部分。无论算法如何发展,最重要的目的依然是保留有结构性的高频信息。有方法利用固定的人早点去保留结构性的高频信息,本文提出采用随即的mask去组合图像与BSN网络处理后的图像,并进行多个模板的累加,从而实现某些高频信息的保留。R3示意图

I M i = M i ⋅ I N o i s y + ( 1 − M i ) ⋅ I B S N s I_{M_{i}}=M_{i}·I_{Noisy}+(1-M_{i})·I^{s}_{BSN} IMi=MiINoisy+(1Mi)IBSNs

代码如下(示例):

    def forward(self, img, pd=None):
        '''
        Foward function includes sequence of PD, BSN and inverse PD processes.
        Note that denoise() function is used during inference time (for differenct pd factor and R3).
        '''
        # default pd factor is training factor (a)
        if pd is None: pd = self.pd_a

        # do PD
        if pd > 1:
            pd_img = pixel_shuffle_down_sampling(img, f=pd, pad=self.pd_pad)
        else:
            p = self.pd_pad
            pd_img = F.pad(img, (p,p,p,p))
        
        # forward blind-spot network
        pd_img_denoised = self.bsn(pd_img)

        # do inverse PD
        if pd > 1:
            img_pd_bsn = pixel_shuffle_up_sampling(pd_img_denoised, f=pd, pad=self.pd_pad)
        else:
            p = self.pd_pad
            img_pd_bsn = pd_img_denoised[:,:,p:-p,p:-p]

        return img_pd_bsn

    def denoise(self, x):
        '''
        Denoising process for inference.
        '''
        b,c,h,w = x.shape

        # pad images for PD process
        if h % self.pd_b != 0:
            x = F.pad(x, (0, 0, 0, self.pd_b - h%self.pd_b), mode='constant', value=0)
        if w % self.pd_b != 0:
            x = F.pad(x, (0, self.pd_b - w%self.pd_b, 0, 0), mode='constant', value=0)

        # forward PD-BSN process with inference pd factor
        img_pd_bsn = self.forward(img=x, pd=self.pd_b)

        # Random Replacing Refinement
        if not self.R3:
            ''' Directly return the result (w/o R3) '''
            return img_pd_bsn[:,:,:h,:w]
        else:
            denoised = torch.empty(*(x.shape), self.R3_T, device=x.device)
            for t in range(self.R3_T):
                indice = torch.rand_like(x)
                mask = indice < self.R3_p

                tmp_input = torch.clone(img_pd_bsn).detach()
                tmp_input[mask] = x[mask]
                p = self.pd_pad
                tmp_input = F.pad(tmp_input, (p,p,p,p), mode='reflect')
                if self.pd_pad == 0:
                    denoised[..., t] = self.bsn(tmp_input)
                else:
                    denoised[..., t] = self.bsn(tmp_input)[:,:,p:-p,p:-p]

            return torch.mean(denoised, dim=-1)

需要注意的是R3算法仅在推理过程中应用


总结

该文章发表于2022年的CVPR,其创新点在于首次实现了现实噪声图像的自监督训练,其余两个创新点均是在应用和方法上进行的小的改进,但是该文章的优点在于对理论和故事讲得非常完善,该文章来自韩国首尔大学,CVPR2023中另有一篇文章在此文章上进行改进,是下一篇的学习目标。另外该文章代码较为规范,内容比较充实,可以作为基础框架使用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值