APbsn

AP-BSN:

Self-Supervised Denoising for Real-World Images via Asymmetric PD and Blind-Spot Network

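1. Spatial correlation of image noise

A blind-spot network (BSN) is mainly used for self-supervised denoising, and it assumes spatially uncorrelated noise.
In real-world images, however, the noise is generally spatially correlated.

For example, in the figure below the noise of adjacent pixels is highly correlated,
and the correlation only becomes low beyond a distance of 5 pixels.
So after building your own dataset, it is worth checking its spatial noise correlation.
[Figure: noise correlation versus pixel distance]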
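
One way to run that check is to estimate the correlation between the noise at a pixel and the noise `d` pixels away. The sketch below is only an illustration in plain Python: `noise_correlation` and `make_correlated_noise` are hypothetical helpers, the noise map is assumed to be available already (for example noisy minus clean on paired data), and the toy noise is a moving average of white noise rather than real camera noise.

```python
import random

def noise_correlation(noise, d):
    """Pearson correlation between noise[y][x] and noise[y][x + d] (horizontal offset d)."""
    a, b = [], []
    for row in noise:
        for x in range(len(row) - d):
            a.append(row[x])
            b.append(row[x + d])
    n = len(a)
    mean_a = sum(a) / n
    mean_b = sum(b) / n
    cov = sum((u - mean_a) * (v - mean_b) for u, v in zip(a, b))
    var_a = sum((u - mean_a) ** 2 for u in a)
    var_b = sum((v - mean_b) ** 2 for v in b)
    return cov / (var_a * var_b) ** 0.5

def make_correlated_noise(h, w, k=3, seed=0):
    """Toy correlated noise: horizontal moving average of white Gaussian noise over k pixels."""
    rng = random.Random(seed)
    white = [[rng.gauss(0.0, 1.0) for _ in range(w + k - 1)] for _ in range(h)]
    return [[sum(row[x:x + k]) / k for x in range(w)] for row in white]

noise = make_correlated_noise(64, 64, k=3)
for d in range(1, 6):
    # Correlation should drop towards zero once d reaches the averaging width k.
    print(d, round(noise_correlation(noise, d), 3))
```

On a real dataset the same curve tells you roughly which PD stride is needed before the noise in the sub-images looks independent.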

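2. Blind-spot network (BSN)

There is a theoretical justification:
[Figure: formula from the paper]

But it only holds under the premise that the noise is pixel-wise independent.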

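3. Pixel subsampling (PD)

The idea is to take pixels at a fixed interval, which breaks the spatial correlation.
[Figure: illustration of pixel subsampling]

There is a trade-off, though: the larger the subsampling stride, the lower the correlation, and the better the noise satisfies the assumption of spatial independence.
But a larger stride also introduces aliasing artifacts.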
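
As a rough illustration of the subsampling step, here is a minimal sketch on a single-channel image stored as a list of rows. `pd_down` and `pd_up` are hypothetical names; this is not the repository's `pixel_shuffle_down_sampling`, which works on batched tensors and also pads between sub-images. It assumes the height and width are divisible by the stride `f`.

```python
def pd_down(img, f):
    """Split a 2D image into f*f sub-images; sub-image (i, j) holds img[i::f, j::f]."""
    return [[[row[j::f] for row in img[i::f]] for j in range(f)] for i in range(f)]

def pd_up(subs, f):
    """Inverse of pd_down: interleave the f*f sub-images back into one image."""
    sub_h = len(subs[0][0])
    sub_w = len(subs[0][0][0])
    img = [[None] * (sub_w * f) for _ in range(sub_h * f)]
    for i in range(f):
        for j in range(f):
            for y in range(sub_h):
                for x in range(sub_w):
                    img[y * f + i][x * f + j] = subs[i][j][y][x]
    return img

# 4x4 image whose pixel value encodes its position (row * 4 + col).
img = [[y * 4 + x for x in range(4)] for y in range(4)]
subs = pd_down(img, 2)
print(subs[0][0])  # pixels from even rows and even columns
restored = pd_up(subs, 2)
```

Neighbouring pixels inside a sub-image were `f` pixels apart in the original image, so their noise correlation is the original correlation at distance `f`. The same spacing is what causes aliasing when `f` is large.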

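With PD stride 2 or stride 5 alone, the results are both poor (the two middle images in the figure below).
The bottom-right image uses stride 5 for training and stride 2 for inference, and it looks much better.
[Figure: results for different PD strides]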

4. Asymmetric PD (APa/b)

We note that a and b are stride factors for training and inference, respectively.

That is, training uses stride 5 to break the noise correlation,
while inference uses stride 2 to avoid aliasing.

[Figure: asymmetric PD]

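5. Random-replacing refinement (R3)

Randomly select some pixels of the original noisy image and use them to replace the corresponding pixels of the denoised image, which gives T new images.

These T images are then fed into the denoising network to obtain T denoised images, which are averaged.
[Figure: random-replacing refinement]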
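
The procedure can be sketched as below, with p = 0.16 and T = 8 as the defaults. This is an illustration only: `r3` and `toy_bsn` are hypothetical names, and `toy_bsn` (mean of the four neighbours) merely stands in for the trained blind-spot network. As in the `denoise()` method of the official code, the refinement passes call the network directly without PD.

```python
import random

def toy_bsn(img):
    """Stand-in for a trained BSN: mean of the 4 neighbours, never the centre pixel."""
    h, w = len(img), len(img[0])
    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            nbrs = [img[yy][xx]
                    for yy, xx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1))
                    if 0 <= yy < h and 0 <= xx < w]
            out[y][x] = sum(nbrs) / len(nbrs)
    return out

def r3(noisy, denoised, bsn, p=0.16, T=8, seed=0):
    """Average T BSN outputs, each on `denoised` with a random fraction p of noisy pixels put back."""
    rng = random.Random(seed)
    h, w = len(noisy), len(noisy[0])
    acc = [[0.0] * w for _ in range(h)]
    for _ in range(T):
        # Each pixel is replaced by its noisy value with probability p.
        mixed = [[noisy[y][x] if rng.random() < p else denoised[y][x]
                  for x in range(w)] for y in range(h)]
        out = bsn(mixed)
        for y in range(h):
            for x in range(w):
                acc[y][x] += out[y][x] / T
    return acc

noisy = [[1.0] * 5 for _ in range(4)]
refined = r3(noisy, noisy, toy_bsn)
```

Because only a sparse random set of noisy pixels is reinserted, the replaced pixels are mostly far apart, so their noise is close to independent even without PD.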

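6. Summary

The paper's main contribution is items 4 and 5 above (asymmetric PD and random-replacing refinement), which improve the blind-spot network.

By default, we adopt AP5/2
and set p and T to 0.16 and 8

code: https://github.com/wooseoklee4/AP-BSN
net:
[Figure: network architecture]

The overall network structure is fairly clear.
It is as follows: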

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..util.util import pixel_shuffle_down_sampling, pixel_shuffle_up_sampling
from . import regist_model
from .DBSNl import DBSNl


@regist_model
class APBSN(nn.Module):
    '''
    Asymmetric PD Blind-Spot Network (AP-BSN)
    '''
    def __init__(self, pd_a=5, pd_b=2, pd_pad=2, R3=True, R3_T=8, R3_p=0.16, 
                    bsn='DBSNl', in_ch=3, bsn_base_ch=128, bsn_num_module=9):
        '''
        Args:
            pd_a           : 'PD stride factor' during training
            pd_b           : 'PD stride factor' during inference
            pd_pad         : pad size between sub-images by PD process
            R3             : flag of 'Random Replacing Refinement'
            R3_T           : number of masks for R3
            R3_p           : probability of R3
            bsn            : blind-spot network type
            in_ch          : number of input image channel
            bsn_base_ch    : number of bsn base channel
            bsn_num_module : number of module
        '''
        super().__init__()

        # network hyper-parameters
        self.pd_a    = pd_a
        self.pd_b    = pd_b
        self.pd_pad  = pd_pad
        self.R3      = R3
        self.R3_T    = R3_T
        self.R3_p    = R3_p
        
        # define network
        if bsn == 'DBSNl':
            self.bsn = DBSNl(in_ch, in_ch, bsn_base_ch, bsn_num_module)
        else:
            raise NotImplementedError('bsn %s is not implemented'%bsn)

    def forward(self, img, pd=None):
        '''
        Forward function includes sequence of PD, BSN and inverse PD processes.
        Note that denoise() function is used during inference time (for different pd factor and R3).
        '''
        # default pd factor is training factor (a)
        if pd is None: pd = self.pd_a

        # do PD
        if pd > 1:
            pd_img = pixel_shuffle_down_sampling(img, f=pd, pad=self.pd_pad)
        else:
            p = self.pd_pad
            pd_img = F.pad(img, (p,p,p,p))
        
        # forward blind-spot network
        pd_img_denoised = self.bsn(pd_img)

        # do inverse PD
        if pd > 1:
            img_pd_bsn = pixel_shuffle_up_sampling(pd_img_denoised, f=pd, pad=self.pd_pad)
        else:
            p = self.pd_pad
            img_pd_bsn = pd_img_denoised[:,:,p:-p,p:-p]

        return img_pd_bsn

    def denoise(self, x):
        '''
        Denoising process for inference.
        '''
        b,c,h,w = x.shape

        # pad images for PD process
        if h % self.pd_b != 0:
            x = F.pad(x, (0, 0, 0, self.pd_b - h%self.pd_b), mode='constant', value=0)
        if w % self.pd_b != 0:
            x = F.pad(x, (0, self.pd_b - w%self.pd_b, 0, 0), mode='constant', value=0)

        # forward PD-BSN process with inference pd factor
        img_pd_bsn = self.forward(img=x, pd=self.pd_b)

        # Random Replacing Refinement
        if not self.R3:
            ''' Directly return the result (w/o R3) '''
            return img_pd_bsn[:,:,:h,:w]
        else:
            denoised = torch.empty(*(x.shape), self.R3_T, device=x.device)
            for t in range(self.R3_T):
                indice = torch.rand_like(x)
                mask = indice < self.R3_p

                tmp_input = torch.clone(img_pd_bsn).detach()
                tmp_input[mask] = x[mask]
                p = self.pd_pad
                tmp_input = F.pad(tmp_input, (p,p,p,p), mode='reflect')
                if self.pd_pad == 0:
                    denoised[..., t] = self.bsn(tmp_input)
                else:
                    denoised[..., t] = self.bsn(tmp_input)[:,:,p:-p,p:-p]

            # crop the PD padding again, as in the non-R3 branch
            return torch.mean(denoised, dim=-1)[:,:,:h,:w]
            
        '''
        elif self.R3 == 'PD-refinement':
            s = 2
            denoised = torch.empty(*(x.shape), s**2, device=x.device)
            for i in range(s):
                for j in range(s):
                    tmp_input = torch.clone(x_mean).detach()
                    tmp_input[:,:,i::s,j::s] = x[:,:,i::s,j::s]
                    p = self.pd_pad
                    tmp_input = F.pad(tmp_input, (p,p,p,p), mode='reflect')
                    if self.pd_pad == 0:
                        denoised[..., i*s+j] = self.bsn(tmp_input)
                    else:
                        denoised[..., i*s+j] = self.bsn(tmp_input)[:,:,p:-p,p:-p]
            return_denoised = torch.mean(denoised, dim=-1)
        else:
            raise RuntimeError('post-processing type not supported')
        '''

The network itself is as follows.
Read it alongside the figure above.

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import regist_model


@regist_model
class DBSNl(nn.Module):
    '''
    Dilated Blind-Spot Network (customized light version)

    self-implemented version of the network from "Unpaired Learning of Deep Image Denoising (ECCV 2020)"
    and several modifications are included. 
    see our supple for more details. 
    '''
    def __init__(self, in_ch=3, out_ch=3, base_ch=128, num_module=9):
        '''
        Args:
            in_ch      : number of input channel
            out_ch     : number of output channel
            base_ch    : number of base channel
            num_module : number of modules in the network
        '''
        super().__init__()

        assert base_ch%2 == 0, "base channel should be divisible by 2"

        ly = []
        ly += [ nn.Conv2d(in_ch, base_ch, kernel_size=1) ]
        ly += [ nn.ReLU(inplace=True) ]
        self.head = nn.Sequential(*ly)

        self.branch1 = DC_branchl(2, base_ch, num_module)
        self.branch2 = DC_branchl(3, base_ch, num_module)

        ly = []
        ly += [ nn.Conv2d(base_ch*2,  base_ch,    kernel_size=1) ]
        ly += [ nn.ReLU(inplace=True) ]
        ly += [ nn.Conv2d(base_ch,    base_ch//2, kernel_size=1) ]
        ly += [ nn.ReLU(inplace=True) ]
        ly += [ nn.Conv2d(base_ch//2, base_ch//2, kernel_size=1) ]
        ly += [ nn.ReLU(inplace=True) ]
        ly += [ nn.Conv2d(base_ch//2, out_ch,     kernel_size=1) ]
        self.tail = nn.Sequential(*ly)

    def forward(self, x):
        x = self.head(x)

        br1 = self.branch1(x)
        br2 = self.branch2(x)

        x = torch.cat([br1, br2], dim=1)

        return self.tail(x)

    def _initialize_weights(self):
        # Liyong version
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, (2 / (9.0 * 64)) ** 0.5)

class DC_branchl(nn.Module):
    def __init__(self, stride, in_ch, num_module):
        super().__init__()

        ly = []
        ly += [ CentralMaskedConv2d(in_ch, in_ch, kernel_size=2*stride-1, stride=1, padding=stride-1) ]
        ly += [ nn.ReLU(inplace=True) ]
        ly += [ nn.Conv2d(in_ch, in_ch, kernel_size=1) ]
        ly += [ nn.ReLU(inplace=True) ]
        ly += [ nn.Conv2d(in_ch, in_ch, kernel_size=1) ]
        ly += [ nn.ReLU(inplace=True) ]

        ly += [ DCl(stride, in_ch) for _ in range(num_module) ]

        ly += [ nn.Conv2d(in_ch, in_ch, kernel_size=1) ]
        ly += [ nn.ReLU(inplace=True) ]
        
        self.body = nn.Sequential(*ly)

    def forward(self, x):
        return self.body(x)

class DCl(nn.Module):
    def __init__(self, stride, in_ch):
        super().__init__()

        ly = []
        ly += [ nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=stride, dilation=stride) ]
        ly += [ nn.ReLU(inplace=True) ]
        ly += [ nn.Conv2d(in_ch, in_ch, kernel_size=1) ]
        self.body = nn.Sequential(*ly)

    def forward(self, x):
        return x + self.body(x)

class CentralMaskedConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.register_buffer('mask', self.weight.data.clone())
        _, _, kH, kW = self.weight.size()
        self.mask.fill_(1)
        self.mask[:, :, kH//2, kW//2] = 0  # zero the centre tap (width index uses kW)

    def forward(self, x):
        self.weight.data *= self.mask
        return super().forward(x)
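
To make the role of `CentralMaskedConv2d` concrete, here is a small plain-Python sketch of the same masking idea on a single kernel. `central_mask` and `masked_response` are hypothetical helpers, not part of the repository; they only show that the response at a position does not depend on the centre pixel.

```python
def central_mask(k_h, k_w):
    """Kernel-shaped mask of ones with a zero at the centre tap."""
    mask = [[1] * k_w for _ in range(k_h)]
    mask[k_h // 2][k_w // 2] = 0
    return mask

def masked_response(patch, weight):
    """Response of a centre-masked kernel on a patch of the same size."""
    mask = central_mask(len(weight), len(weight[0]))
    return sum(
        patch[i][j] * weight[i][j] * mask[i][j]
        for i in range(len(weight))
        for j in range(len(weight[0]))
    )

weight = [[1.0, 2.0, 1.0], [2.0, 4.0, 2.0], [1.0, 2.0, 1.0]]
patch = [[1.0, 1.0, 1.0], [1.0, 9.0, 1.0], [1.0, 1.0, 1.0]]

# Change only the centre pixel: the masked response must stay the same.
changed = [row[:] for row in patch]
changed[1][1] = -100.0
print(masked_response(patch, weight), masked_response(changed, weight))
```

In `DC_branchl`, the masked kernel has size 2*stride-1 and every later 3x3 convolution uses dilation=stride. The dilated taps move in multiples of `stride`, while the masked kernel only reaches offsets smaller than `stride`, so no path leads back to the centre pixel and the blind spot survives through the whole branch.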

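7. Experiments

I trained and tested on a dataset I built myself.
There are two inference methods: A is the one from the paper, unshuffle(2) + model + shuffle(2); B feeds the image directly into the model.

First, I ran experiments with the official blind-spot network source code.
With pd_a, pd_b = 5, 2, the results of both A and B are very smooth.
With pd_a, pd_b = 2, 2, the results of A and B are slightly better and keep more detail, but a lot of texture is still missing, which is not satisfactory.

Then I replaced the blind-spot network with a U-Net. The network learns an identity mapping, because the input and the target are the same.

In short, none of the results were good.

1) Why does the paper's result on SIDD come close to paired training, while on my own dataset this method is far behind supervised paired training? Could it be the noise correlation of the dataset? But training was poor with both pd=2 and pd=5.
2) Different scenes? That may be part of it. The trained model denoises reasonably well on relatively clean images, but wherever the image texture is slightly irregular, it gets wiped out.
3) Although these are also noisy images taken by a real camera, the ISP processing is not very standard compared with sRGB. The noise pattern is fairly complex, which may be the reason.

4) Is my training code wrong? I also ran experiments with another method that improves on this paper, Random Sub-Samples Generation for Self-Supervised Real Image Denoising. This time I used the official code directly and changed only the dataset, and I got similar results: the images from the blind-spot network are all very smooth. Likewise, the results trained with Asymmetric Mask Scheme for Self-Supervised Real Image Denoising also lose detail. The three papers share similar ideas and build on one another.
5) How to improve the results needs more analysis.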
