AP-BSN:
Self-Supervised Denoising for Real-World Images via Asymmetric PD and Blind-Spot Network
1.图像噪声的相关性
BSN(盲点网络,blind-spot network)主要用于自监督去噪,但它假设噪声在空间上不相关(像素之间相互独立)。
而真实图像中的噪声一般是空间相关的。
比如下图:相邻像素的噪声相关性是比较高的
像素距离 distance 达到 5 之后,相关度就比较低了。
所以制作好自己的数据集之后,要先统计一下噪声的空间相关度。
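2.BSN 盲点网络
有理论证明:盲点网络在预测每个像素时看不到该像素本身的输入。若噪声为零均值,且在像素之间相互独立,则以带噪图像自身作为目标的自监督损失,在期望上等于有监督损失加上一个与网络无关的常数。因此训练时可以不需要干净图像。
但这一结论的前提是像素之间的噪声不相关。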
3.pixel subsample方法
即按固定间隔(stride)抽取像素,组成若干子图像,以降低噪声的空间相关性。
但这里有个矛盾:下采样间隔越大,子图像内相邻像素的噪声相关性越低,越满足"噪声在空间上独立、不相关"的假设;
但下采样间隔越大,同时也会引入越严重的混叠伪影(aliasing artifacts)。
训练和推理都使用 PD stride=2,或者都使用 stride=5 时,效果都不太好(见原文图中间的 2 张图)。
最右下角是训练使用 stride=5、推理使用 stride=2 的结果,效果好了很多。
4.Asymmetric PD (APa/b).
We note that a and b are stride factors for training and inference, respectively.
即训练时用 stride=5,来去除噪声的空间相关性;
推理时用 stride=2,来减轻混叠(aliasing)。
5.Random-Replacing Refinement(R3,随机替换细化)
即以概率 p 随机选取原始噪声图像中的一些像素,替换降噪结果中对应位置的像素,重复 T 次,得到 T 张新图像;
然后将它们分别输入降噪网络(这一步不再做 PD),得到 T 张降噪图像,再求平均。
6.小结
本文主要就是提出第 4、5 两点来改进 blind-spot network。
By default, we adopt AP5/2
and set p and T to 0.16 and 8, respectively.
code: https://github.com/wooseoklee4/AP-BSN
net:
整体网络的结构还是比较清晰的。
如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..util.util import pixel_shuffle_down_sampling, pixel_shuffle_up_sampling
from . import regist_model
from .DBSNl import DBSNl
@regist_model
class APBSN(nn.Module):
    '''
    Asymmetric PD Blind-Spot Network (AP-BSN).

    Wraps a blind-spot network (BSN) with pixel-shuffle down-sampling (PD).
    `forward` uses the training PD stride `pd_a` by default. `denoise` is the
    inference entry point: it uses the stride `pd_b` and optionally applies
    Random-Replacing Refinement (R3) afterwards.
    '''

    def __init__(self, pd_a=5, pd_b=2, pd_pad=2, R3=True, R3_T=8, R3_p=0.16,
                 bsn='DBSNl', in_ch=3, bsn_base_ch=128, bsn_num_module=9):
        '''
        Args:
            pd_a           : PD stride factor used during training
            pd_b           : PD stride factor used during inference
            pd_pad         : pad size between sub-images in the PD process
            R3             : whether denoise() applies Random-Replacing Refinement
            R3_T           : number of random masks (refinement passes) for R3
            R3_p           : probability that a pixel is replaced by its noisy value in R3
            bsn            : blind-spot network type (only 'DBSNl' is supported)
            in_ch          : number of input image channels
            bsn_base_ch    : number of BSN base channels
            bsn_num_module : number of modules in the BSN
        Raises:
            NotImplementedError: if `bsn` is not a supported network type.
        '''
        super().__init__()
        # network hyper-parameters
        self.pd_a = pd_a
        self.pd_b = pd_b
        self.pd_pad = pd_pad
        self.R3 = R3
        self.R3_T = R3_T
        self.R3_p = R3_p

        # define network
        if bsn == 'DBSNl':
            self.bsn = DBSNl(in_ch, in_ch, bsn_base_ch, bsn_num_module)
        else:
            raise NotImplementedError('bsn %s is not implemented'%bsn)

    @staticmethod
    def _crop(t, p):
        '''Strip `p` pixels from both spatial borders (last two dims) of `t`; no-op for p == 0.'''
        # `t[:, :, 0:-0, 0:-0]` would be an empty slice, hence the explicit guard.
        return t if p == 0 else t[:, :, p:-p, p:-p]

    def forward(self, img, pd=None):
        '''
        Run PD -> BSN -> inverse PD.

        denoise() calls this with the inference factor `pd_b`, then applies R3.

        Args:
            img : input batch; denoise() passes a 4-D tensor (b, c, h, w).
            pd  : PD stride factor; defaults to the training factor `pd_a`.
                  For pd <= 1 no sub-sampling is done. The image is only
                  zero-padded by `pd_pad` and cropped back afterwards.
        Returns:
            The BSN output after inverse PD. It presumably has the same
            spatial size as `img`; this depends on pixel_shuffle_up_sampling
            and on the BSN preserving spatial size (TODO confirm).
        '''
        # default pd factor is training factor (a)
        if pd is None:
            pd = self.pd_a
        p = self.pd_pad

        # do PD
        if pd > 1:
            pd_img = pixel_shuffle_down_sampling(img, f=pd, pad=p)
        else:
            pd_img = F.pad(img, (p, p, p, p))

        # forward blind-spot network
        pd_img_denoised = self.bsn(pd_img)

        # do inverse PD
        if pd > 1:
            return pixel_shuffle_up_sampling(pd_img_denoised, f=pd, pad=p)
        # _crop keeps pd_pad == 0 from producing an empty tensor.
        return self._crop(pd_img_denoised, p)

    def denoise(self, x):
        '''
        Inference-time denoising with PD stride `pd_b` and optional R3.

        Args:
            x : 4-D input batch (b, c, h, w).
        Returns:
            The denoised tensor, cropped back to the original (h, w) in both
            the R3 and the non-R3 path.
        Raises:
            ValueError: if R3 is enabled with R3_T < 1.
        '''
        _, _, h, w = x.shape

        # Zero-pad the bottom/right so that H and W are multiples of pd_b
        # for the PD process.
        pad_h = (-h) % self.pd_b
        pad_w = (-w) % self.pd_b
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='constant', value=0)

        # forward PD-BSN process with inference pd factor
        img_pd_bsn = self.forward(img=x, pd=self.pd_b)

        if not self.R3:
            # Directly return the result (w/o R3)
            return img_pd_bsn[:, :, :h, :w]

        if self.R3_T < 1:
            raise ValueError('R3_T must be >= 1 when R3 is enabled, got %s' % self.R3_T)

        # Random-Replacing Refinement. Repeat R3_T times:
        #   1. put a random fraction R3_p of the noisy pixels back into the
        #      PD-BSN estimate,
        #   2. run the BSN once more (without PD),
        # then average the R3_T outputs. A running sum avoids keeping all
        # R3_T full-size outputs in memory.
        p = self.pd_pad
        total = None
        for _ in range(self.R3_T):
            mask = torch.rand_like(x) < self.R3_p
            tmp_input = img_pd_bsn.detach().clone()
            tmp_input[mask] = x[mask]
            tmp_input = F.pad(tmp_input, (p, p, p, p), mode='reflect')
            out = self._crop(self.bsn(tmp_input), p)
            total = out if total is None else total + out

        # Remove the divisibility padding added above, as the non-R3 path does.
        return (total / self.R3_T)[:, :, :h, :w]
具体网络如下:
结合上图查看
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import regist_model
@regist_model
class DBSNl(nn.Module):
    '''
    Dilated Blind-Spot Network (customized light version).

    Self-implemented version of the network from "Unpaired Learning of Deep
    Image Denoising (ECCV 2020)", with several modifications (see the AP-BSN
    supplementary material for details).

    Layout:
        1x1 head -> two DC_branchl branches (stride factors 2 and 3)
        -> channel concat -> 1x1 tail.
    '''

    def __init__(self, in_ch=3, out_ch=3, base_ch=128, num_module=9):
        '''
        Args:
            in_ch      : number of input channels
            out_ch     : number of output channels
            base_ch    : number of base channels (must be even)
            num_module : number of modules in each branch
        '''
        super().__init__()
        assert base_ch % 2 == 0, "base channel should be divided with 2"
        half_ch = base_ch // 2

        # Submodules are created in a fixed order (head, branch1, branch2,
        # tail). This keeps parameter initialisation and state_dict keys stable.
        self.head = nn.Sequential(
            nn.Conv2d(in_ch, base_ch, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        self.branch1 = DC_branchl(2, base_ch, num_module)
        self.branch2 = DC_branchl(3, base_ch, num_module)
        self.tail = nn.Sequential(
            nn.Conv2d(2 * base_ch, base_ch, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_ch, half_ch, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(half_ch, half_ch, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(half_ch, out_ch, kernel_size=1),
        )

    def forward(self, x):
        '''Apply the head, run both branches on its output, concat on channels, apply the tail.'''
        feat = self.head(x)
        fused = torch.cat([self.branch1(feat), self.branch2(feat)], dim=1)
        return self.tail(fused)

    def _initialize_weights(self):
        '''Re-draw every Conv2d weight from N(0, 2 / (9 * 64)) (Liyong version); biases untouched.'''
        # The std is a fixed constant, independent of each layer's actual shape.
        std = (2 / (9.0 * 64)) ** 0.5
        for conv in (mod for mod in self.modules() if isinstance(mod, nn.Conv2d)):
            conv.weight.data.normal_(0, std)
class DC_branchl(nn.Module):
    '''
    One branch of DBSNl.

    Layout: a centre-masked convolution of size (2*stride - 1), two 1x1
    conv + ReLU pairs, `num_module` DCl residual blocks with dilation
    `stride`, and a final 1x1 conv + ReLU.
    '''

    def __init__(self, stride, in_ch, num_module):
        super().__init__()
        # Padding stride-1 equals kernel_size//2 for the odd kernel 2*stride-1.
        layers = [
            CentralMaskedConv2d(in_ch, in_ch, kernel_size=2 * stride - 1,
                                stride=1, padding=stride - 1),
            nn.ReLU(inplace=True),
        ]
        for _ in range(2):
            layers.extend([nn.Conv2d(in_ch, in_ch, kernel_size=1), nn.ReLU(inplace=True)])
        layers.extend(DCl(stride, in_ch) for _ in range(num_module))
        layers.extend([nn.Conv2d(in_ch, in_ch, kernel_size=1), nn.ReLU(inplace=True)])
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        '''Run the layer stack sequentially.'''
        return self.body(x)
class DCl(nn.Module):
    '''
    Dilated residual block: dilated 3x3 conv -> ReLU -> 1x1 conv, added to the input.

    Padding equals the dilation, so the 3x3 conv preserves the spatial size
    and the residual addition is shape-compatible.
    '''

    def __init__(self, stride, in_ch):
        super().__init__()
        dilated = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1,
                            padding=stride, dilation=stride)
        pointwise = nn.Conv2d(in_ch, in_ch, kernel_size=1)
        self.body = nn.Sequential(dilated, nn.ReLU(inplace=True), pointwise)

    def forward(self, x):
        '''Return x plus the body's output.'''
        residual = self.body(x)
        return x + residual
class CentralMaskedConv2d(nn.Conv2d):
    '''
    nn.Conv2d whose kernel centre tap is forced to zero.

    A `mask` buffer (ones everywhere except the kernel centre) is multiplied
    into the weights in place before every convolution. With stride 1, an odd
    kernel and padding of kernel_size//2, the output at a location therefore
    never uses the input value at that same location.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        _, _, kH, kW = self.weight.size()
        # A buffer (saved in state_dict) with the weight's shape/dtype/device.
        self.register_buffer('mask', torch.ones_like(self.weight.data))
        # Use kW for the column index so that non-square kernels are masked
        # at their true centre (kH was previously used for both axes).
        self.mask[:, :, kH // 2, kW // 2] = 0
        # Mask once here too, so the stored weights are consistent even
        # before the first forward call.
        self.weight.data *= self.mask

    def forward(self, x):
        # In-place on .data: the parameter stays masked without recording the
        # op in autograd. The centre weight is re-zeroed on every call, before
        # it is used.
        self.weight.data *= self.mask
        return super().forward(x)
7.实验
在我自己制作的数据集上训练和测试
一共 2 种预测方法:A 是原文中的做法,即先 unshuffle(PD, stride=2)+ model + shuffle;B 是直接将图像输入 model 预测。
首先按照官方源代码blind spot net试验
pda,pdb分别为5,2的时候, 得到的结果A,B都很平滑
pda,pdb分别为2,2的时候, 得到的结果A,B会稍好一些,多一些细节,但是仍然缺失很多纹理,不是很满意。
然后用 UNet 替换 blind-spot network,网络会学到一个恒等变换,因为 input 和 target 是相同的(没有盲点约束时,直接输出输入就能使损失最小)。
总之,没有一个好的结果。
1)为什么论文中在 SIDD 上的表现接近成对(paired)有监督训练,而在我自己的数据集上,本方法与有监督的成对训练差异巨大?可能是数据集噪声空间相关度的问题?但是训练的时候,无论 pd=2 还是 pd=5,效果都不太好。
2) 场景不同?可能有部分因素,训练的模型在某些 比较干净的图像上降噪效果还可以,但是图像纹理稍微不规则的情况下,就会被抹去
3)虽然也是实际相机拍摄的噪声图像,但是由于 isp处理 相比sRGB不是很标准。噪声形态比较复杂,可能是这个原因。
4)我的训练代码有误?在另一个相对本文改进的方法Random Sub-Samples Generation for Self-Supervised Real Image Denoising我也进行了相关实验,这次直接使用官方代码只改了数据集,可以得到类似的结果,就是blind spot net得到的图像都很平滑。同样在Asymmetric Mask Scheme for Self-Supervised Real Image Denoising方法训练的结果也丢失细节。这三篇论文思路是接近和改进。
5)如何提升效果, 需要更多分析。