Preface

AP-BSN builds on the blind-spot network (BSN) by modifying the shuffle mechanism to break the spatial continuity that noise exhibits in natural images. The paper makes three main contributions:
- an asymmetric pixel-shuffle downsampling (AP) scheme;
- a random-replacing refinement (R3) post-processing step;
- the first self-supervised network aimed at real-world image denoising.
I. BSN and Pixel-shuffle Downsampling
A blind-spot network (Blind-spot network, BSN) is a variant of the convolutional neural network in which, as shown in the figure, the center of the convolution kernel is zeroed out, so the network never sees the pixel it is predicting.
Training such a network relies on two assumptions:
- Assumption 1: the noise is spatially pixel-wise independent (not satisfied by real-world images);
- Assumption 2: the noise is independent and zero-mean (not satisfied in deep-space imaging).
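The blind-spot idea can be sketched as a convolution whose kernel center is forced to zero, so each output pixel is estimated only from its neighbors. A minimal NumPy illustration (the 3×3 averaging kernel and toy image here are my own, not the paper's):

```python
import numpy as np

def blind_spot_conv(img: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Convolve with the kernel center zeroed, so each output value
    never depends on its own input pixel (the 'blind spot')."""
    k = kernel.copy()
    kh, kw = k.shape
    k[kh // 2, kw // 2] = 0.0                       # mask the center
    padded = np.pad(img, ((kh // 2, kh // 2), (kw // 2, kw // 2)), mode="reflect")
    out = np.zeros_like(img, dtype=float)
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            out[i, j] = np.sum(padded[i:i + kh, j:j + kw] * k)
    return out

img = np.arange(16, dtype=float).reshape(4, 4)
kernel = np.full((3, 3), 1.0 / 8)                   # average of the 8 neighbors only
est = blind_spot_conv(img, kernel)
```

Because of the masked center, changing a pixel's own value never changes its own estimate, which is exactly why zero-mean, pixel-wise independent noise can be trained away without clean targets.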
Since the noise in most real-world images is spatially correlated — the noise in both the SIDD and DND datasets demonstrates this — Assumption 1 does not hold. To address this, researchers designed a downsampling scheme that can be understood intuitively as sampling pixels at equal intervals, or equivalently as pooling the image at several different phase offsets, as shown in the figure.
This operation has two effects: it breaks the spatial correlation of the noise, but it also creates artifacts. In the green-boxed image at the bottom right, the position marked by the red arrow is actually a wooden pole, but after PD it becomes an isolated pixel, and the BSN inference step then treats it as noise and suppresses it.
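The equal-interval sampling can be sketched with plain array strides. A minimal NumPy illustration of splitting an image into its PD phase sub-images (function and variable names are my own):

```python
import numpy as np

def pd_phases(img: np.ndarray, f: int) -> np.ndarray:
    """Split an HxW image into f*f sub-images by taking every f-th pixel
    at each phase offset (i, j). Adjacent pixels within a sub-image were
    f pixels apart in the original, which breaks short-range noise
    correlation."""
    h, w = img.shape
    assert h % f == 0 and w % f == 0
    return np.stack([img[i::f, j::f] for i in range(f) for j in range(f)])

img = np.arange(36).reshape(6, 6)
subs = pd_phases(img, 2)    # 4 sub-images of shape 3x3
```

In `subs[0]`, the horizontally adjacent values 0 and 2 were two columns apart in the original image, so noise that is only correlated over one-pixel distances becomes independent between them.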
II. Improvements

1. Asymmetric downsampling

The asymmetric downsampling is where the "AP" in the name comes from, and it is the paper's most important contribution. In short, $PD_5$ (stride-5 PD) is used during training to break the spatial correlation of the noise, while $PD_2$ (stride-2 PD) is used during inference to suppress artifacts as much as possible.
The code for PD is:

```python
import torch
import torch.nn.functional as F

def pixel_shuffle_down_sampling(x: torch.Tensor, f: int, pad: int = 0, pad_value: float = 0.):
    '''
    pixel-shuffle down-sampling (PD) from "When AWGN-denoiser meets real-world noise." (AAAI 2019)
    Args:
        x (Tensor) : input tensor
        f (int) : factor of PD
        pad (int) : number of pad between each down-sampled images
        pad_value (float) : padding value
    Return:
        pd_x (Tensor) : down-shuffled image tensor with pad or not
    '''
    # single image tensor
    if len(x.shape) == 3:
        c, w, h = x.shape
        unshuffled = F.pixel_unshuffle(x, f)
        if pad != 0:
            unshuffled = F.pad(unshuffled, (pad, pad, pad, pad), value=pad_value)
        return unshuffled.view(c, f, f, w//f + 2*pad, h//f + 2*pad).permute(0, 1, 3, 2, 4).reshape(c, w + 2*f*pad, h + 2*f*pad)
    # batched image tensor
    else:
        b, c, w, h = x.shape
        unshuffled = F.pixel_unshuffle(x, f)
        if pad != 0:
            unshuffled = F.pad(unshuffled, (pad, pad, pad, pad), value=pad_value)
        return unshuffled.view(b, c, f, f, w//f + 2*pad, h//f + 2*pad).permute(0, 1, 2, 4, 3, 5).reshape(b, c, w + 2*f*pad, h + 2*f*pad)

def pixel_shuffle_up_sampling(x: torch.Tensor, f: int, pad: int = 0):
    '''
    inverse of pixel-shuffle down-sampling (PD)
    see more details about PD in pixel_shuffle_down_sampling()
    Args:
        x (Tensor) : input tensor
        f (int) : factor of PD
        pad (int) : number of pad to be removed
    '''
    # single image tensor
    if len(x.shape) == 3:
        c, w, h = x.shape
        before_shuffle = x.view(c, f, w//f, f, h//f).permute(0, 1, 3, 2, 4).reshape(c*f*f, w//f, h//f)
        if pad != 0:
            before_shuffle = before_shuffle[..., pad:-pad, pad:-pad]
        return F.pixel_shuffle(before_shuffle, f)
    # batched image tensor
    else:
        b, c, w, h = x.shape
        before_shuffle = x.view(b, c, f, w//f, f, h//f).permute(0, 1, 2, 4, 3, 5).reshape(b, c*f*f, w//f, h//f)
        if pad != 0:
            before_shuffle = before_shuffle[..., pad:-pad, pad:-pad]
        return F.pixel_shuffle(before_shuffle, f)
```
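The two operations above are inverses of each other. A torch-free NumPy sketch of the same round trip (my own minimal reimplementation, without the pad option, to show the tiling and its inverse):

```python
import numpy as np

def pd_down(img: np.ndarray, f: int) -> np.ndarray:
    """Tile the f*f phase sub-images into one mosaic, like PD with pad=0."""
    h, w = img.shape
    sh, sw = h // f, w // f
    out = np.empty_like(img)
    for i in range(f):
        for j in range(f):
            out[i*sh:(i+1)*sh, j*sw:(j+1)*sw] = img[i::f, j::f]
    return out

def pd_up(mosaic: np.ndarray, f: int) -> np.ndarray:
    """Inverse of pd_down: scatter each tile back to its phase positions."""
    h, w = mosaic.shape
    sh, sw = h // f, w // f
    out = np.empty_like(mosaic)
    for i in range(f):
        for j in range(f):
            out[i::f, j::f] = mosaic[i*sh:(i+1)*sh, j*sw:(j+1)*sw]
    return out
```

Applying `pd_up(pd_down(img, f), f)` recovers the original image exactly, which is what lets the network denoise in the shuffled domain and map the result back.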
2. Random Replacing Refinement (R3)

The most classic image denoising algorithm is mean filtering. As a low-pass filter, its main effect is to suppress the high-frequency content of the image; however denoising algorithms evolve, the key goal remains preserving the structured high-frequency information. Some methods do this with fixed, hand-designed sampling points; this paper instead uses random masks to combine the noisy image with the BSN output, accumulating over multiple masks so that part of the high-frequency information is retained:

$$I_{M_i} = M_i \cdot I_{Noisy} + (1 - M_i) \cdot I^{s}_{BSN}$$
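The masked combination can be sketched in NumPy. Note this is a simplified illustration: the paper re-denoises each combined image with the BSN before averaging, which the stand-in below omits; `p` and `T` correspond to the code's `R3_p` and `R3_T` hyperparameters:

```python
import numpy as np

def r3_refine(noisy: np.ndarray, bsn_out: np.ndarray,
              p: float = 0.16, T: int = 8, seed: int = 0) -> np.ndarray:
    """Random Replacing Refinement sketch: T times, replace a random
    fraction p of the BSN output with the original noisy pixels, then
    average the T combined images (the real method passes each combined
    image through the BSN again before averaging)."""
    rng = np.random.default_rng(seed)
    acc = np.zeros_like(bsn_out, dtype=float)
    for _ in range(T):
        mask = rng.random(noisy.shape) < p            # M_i
        acc += np.where(mask, noisy, bsn_out)         # M_i*I_Noisy + (1-M_i)*I_BSN
    return acc / T
```

Averaging over several random masks means each pixel is, with probability p per round, taken straight from the noisy input, which re-injects high-frequency detail that the blind-spot network would otherwise smooth away.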
The code is as follows (example):

```python
def forward(self, img, pd=None):
    '''
    Forward function includes sequence of PD, BSN and inverse PD processes.
    Note that denoise() function is used during inference time (for different pd factor and R3).
    '''
    # default pd factor is training factor (a)
    if pd is None:
        pd = self.pd_a
    # do PD
    if pd > 1:
        pd_img = pixel_shuffle_down_sampling(img, f=pd, pad=self.pd_pad)
    else:
        p = self.pd_pad
        pd_img = F.pad(img, (p, p, p, p))
    # forward blind-spot network
    pd_img_denoised = self.bsn(pd_img)
    # do inverse PD
    if pd > 1:
        img_pd_bsn = pixel_shuffle_up_sampling(pd_img_denoised, f=pd, pad=self.pd_pad)
    else:
        p = self.pd_pad
        img_pd_bsn = pd_img_denoised[:, :, p:-p, p:-p]
    return img_pd_bsn

def denoise(self, x):
    '''
    Denoising process for inference.
    '''
    b, c, h, w = x.shape
    # pad images for PD process
    if h % self.pd_b != 0:
        x = F.pad(x, (0, 0, 0, self.pd_b - h % self.pd_b), mode='constant', value=0)
    if w % self.pd_b != 0:
        x = F.pad(x, (0, self.pd_b - w % self.pd_b, 0, 0), mode='constant', value=0)
    # forward PD-BSN process with inference pd factor
    img_pd_bsn = self.forward(img=x, pd=self.pd_b)
    # Random Replacing Refinement
    if not self.R3:
        # directly return the result (w/o R3)
        return img_pd_bsn[:, :, :h, :w]
    else:
        denoised = torch.empty(*(x.shape), self.R3_T, device=x.device)
        for t in range(self.R3_T):
            indice = torch.rand_like(x)
            mask = indice < self.R3_p
            tmp_input = torch.clone(img_pd_bsn).detach()
            tmp_input[mask] = x[mask]
            p = self.pd_pad
            tmp_input = F.pad(tmp_input, (p, p, p, p), mode='reflect')
            if self.pd_pad == 0:
                denoised[..., t] = self.bsn(tmp_input)
            else:
                denoised[..., t] = self.bsn(tmp_input)[:, :, p:-p, p:-p]
        return torch.mean(denoised, dim=-1)
```
Note that the R3 algorithm is applied only during inference.
Summary

The paper was published at CVPR 2022. Its key contribution is being the first to achieve self-supervised training on real-world noisy images; the other two contributions are smaller improvements in method and application. Its strength, though, lies in how thoroughly the theory and the narrative are developed. The paper comes from Seoul National University in Korea, and a CVPR 2023 paper builds on it, which will be the subject of the next post. In addition, the released code is clean and fairly complete, and can serve as a base framework.