Multi-Stage Progressive Image Restoration(多阶段渐进图像修复)
最近组会读了一篇cvpr的论文,做了一点笔记,内容包含我的理解以及翻译的论文(含注解)以及代码。
读后有感
结果展示:
去模糊展示
gopro数据集samlple_0原图片:
去模糊结果:
去雨展示
原图:
去雨结果:
去噪展示
无损:
原图:
结果:
tips:2276*1280的像素大小的图片,在去模糊以及去噪任务中,推理过程所需超过12g显存。
论文背景
现状很少有工作将多阶段网络结构引入图像修复问题中。
现有多阶段与单阶段各有优缺点。
多阶段
优点:拥有较大的感受野
缺点:缺少局部空间细节
单阶段
优点:拥有精确的空间细节
缺点:感受野较小
目标与贡献
所以本文提出了一种新的跨阶段特征融合方法
且在渐进修复的每个阶段提供真实的监督是很重要的。
所以本文提出一个有效的监督注意力模块,充分利用每个阶段修复的图像在进一步传播之前对输入特征进行细化。
不同的任务的细节
MPRNET训练了三种任务(去雨、去模糊、去噪)的模型三种模型只在模型的特征通道数量上有所不同(去模糊>去噪>去雨)。
但是训练细节方面去噪任务中的训练所使用的损失函数仅为
,而去雨与去模糊中使用的损失函数不同。
而且去噪的训练过程中还有数据增强。代码如下所示:
class MixUp_AUG:
def __init__(self):
self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6]))
def aug(self, rgb_gt, rgb_noisy):#随机抽取批次数量的ground Truth与noisy图片并利用beta分布采样出一个随机数向量然后进行对每一对图片进行融合,这种数据增强技术可以增加模型训练的多样性,帮助模型更好地学习数据的特征,提高泛化能力。
bs = rgb_gt.size(0)
indices = torch.randperm(bs)
rgb_gt2 = rgb_gt[indices]
rgb_noisy2 = rgb_noisy[indices]
lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda()
rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2
rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2
return rgb_gt, rgb_noisy
代码中定义的conv函数为:
def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size//2), bias=bias, stride = stride)
通道注意力模块为:
使用一个卷积一个激活函数再接一个卷积之后,使用全局平均卷积层将每个通道的变为1*1之后使用卷积注意于提取每个通道的特征的信息。
## Channel Attention Layer
class CALayer(nn.Module):
def __init__(self, channel, reduction=16, bias=False):
super(CALayer, self).__init__()
# global average pooling: feature --> point
self.avg_pool = nn.AdaptiveAvgPool2d(1)
# feature channel downscale and upscale --> channel weight
self.conv_du = nn.Sequential(
nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
nn.ReLU(inplace=True),
nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
nn.Sigmoid()
)
def forward(self, x):
y = self.avg_pool(x)
y = self.conv_du(y)
return x * y
class CAB(nn.Module):
def __init__(self, n_feat, kernel_size, reduction, bias, act):
super(CAB, self).__init__()
modules_body = []
modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
modules_body.append(act)
modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
self.CA = CALayer(n_feat, reduction, bias=bias)
self.body = nn.Sequential(*modules_body)
def forward(self, x):
res = self.body(x)
res = self.CA(res)
res += x
return res
ORSNET:![](https://img-blog.csdnimg.cn/403ab6c208434f3d940823ce1d524da8.png)
(上图为ORB,下图为ORSNET)
## Original Resolution Block (ORB)
class ORB(nn.Module):
def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab):
super(ORB, self).__init__()
modules_body = []
modules_body = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(num_cab)]
modules_body.append(conv(n_feat, n_feat, kernel_size))
self.body = nn.Sequential(*modules_body)
def forward(self, x):
res = self.body(x)
res += x
return res
class ORSNet(nn.Module):
def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab):
super(ORSNet, self).__init__()
self.orb1 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
self.orb2 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
self.orb3 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
self.up_enc1 = UpSample(n_feat, scale_unetfeats)
self.up_dec1 = UpSample(n_feat, scale_unetfeats)
self.up_enc2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats))
self.up_dec2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats))
self.conv_enc1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
self.conv_enc2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
self.conv_enc3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
self.conv_dec1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
self.conv_dec2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
self.conv_dec3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
def forward(self, x, encoder_outs, decoder_outs):
x = self.orb1(x)
x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0])
x = self.orb2(x)
x = x + self.conv_enc2(self.up_enc1(encoder_outs[1])) + self.conv_dec2(self.up_dec1(decoder_outs[1]))
x = self.orb3(x)
x = x + self.conv_enc3(self.up_enc2(encoder_outs[2])) + self.conv_dec3(self.up_dec2(decoder_outs[2]))
return x
##
下采样与上采样模块定义:
使用双线性插值之后进行卷积,而不使用转置卷积就是为了避免出现棋盘格伪影的现象。
class DownSample(nn.Module):
def __init__(self, in_channels,s_factor):
super(DownSample, self).__init__()
self.down = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
nn.Conv2d(in_channels, in_channels+s_factor, 1, stride=1, padding=0, bias=False))
def forward(self, x):
x = self.down(x)
return x
class UpSample(nn.Module):
def __init__(self, in_channels,s_factor):
super(UpSample, self).__init__()
self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False))
def forward(self, x):
x = self.up(x)
return x
译码器解码器结构:
使用u-net的方式多尺度提取特征,是为了顾全细节于全局信息
下面该模块用于u-net中,作用是合并上采样第一个输入参数并与第二个输入参数合并。
class SkipUpSample(nn.Module):
def __init__(self, in_channels,s_factor):
super(SkipUpSample, self).__init__()
self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False))
def forward(self, x, y):
x = self.up(x)
x = x + y
return x
编码器
## U-Net
class Encoder(nn.Module):
def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff):
super(Encoder, self).__init__()
self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
self.encoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
self.encoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
self.encoder_level1 = nn.Sequential(*self.encoder_level1)
self.encoder_level2 = nn.Sequential(*self.encoder_level2)
self.encoder_level3 = nn.Sequential(*self.encoder_level3)
self.down12 = DownSample(n_feat, scale_unetfeats)
self.down23 = DownSample(n_feat+scale_unetfeats, scale_unetfeats)
# Cross Stage Feature Fusion (CSFF)
if csff:
self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
self.csff_enc2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias)
self.csff_enc3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias)
self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
self.csff_dec2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias)
self.csff_dec3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias)
def forward(self, x, encoder_outs=None, decoder_outs=None):
enc1 = self.encoder_level1(x)
if (encoder_outs is not None) and (decoder_outs is not None):
enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0])
x = self.down12(enc1)
enc2 = self.encoder_level2(x)
if (encoder_outs is not None) and (decoder_outs is not None):
enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1])
x = self.down23(enc2)
enc3 = self.encoder_level3(x)
if (encoder_outs is not None) and (decoder_outs is not None):
enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2])
return [enc1, enc2, enc3]
解码器
class Decoder(nn.Module):
def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats):
super(Decoder, self).__init__()
self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
self.decoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
self.decoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
self.decoder_level1 = nn.Sequential(*self.decoder_level1)
self.decoder_level2 = nn.Sequential(*self.decoder_level2)
self.decoder_level3 = nn.Sequential(*self.decoder_level3)
self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act)
self.skip_attn2 = CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act)
self.up21 = SkipUpSample(n_feat, scale_unetfeats)
self.up32 = SkipUpSample(n_feat+scale_unetfeats, scale_unetfeats)
def forward(self, outs):
enc1, enc2, enc3 = outs
dec3 = self.decoder_level3(enc3)
x = self.up32(dec3, self.skip_attn2(enc2))
dec2 = self.decoder_level2(x)
x = self.up21(dec2, self.skip_attn1(enc1))
dec1 = self.decoder_level1(x)
return [dec1,dec2,dec3]
监督注意模块(SAM):
使用原图片生成掩码(注意于哪些特征对下一个阶段更加有利)。三个卷积都是1*1的卷积。
卷积之后将输出与受损的图片相加并再经过一个1*1卷积得到注意力掩码之后乘以另一条分支再与上个阶段的输出相加得到传递给下一个阶段的输出。
为什么称为监督注意力呢?
因为他在生成注意力之前的部分给他连接了个loss使得用于生成注意力的部分是本阶段中最像ground truth的,然后经过卷积以及sgmoid生成的注意力,消融实验也证明这个模块是有效的,经过loss监督之后的部分的输出最接近ground truth是这样生成的注意力更好的。
class SAM(nn.Module):
def __init__(self, n_feat, kernel_size, bias):
super(SAM, self).__init__()
self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)
self.conv3 = conv(3, n_feat, kernel_size, bias=bias)
def forward(self, x, x_img):
x1 = self.conv1(x)
img = self.conv2(x) + x_img#注意力部分
x2 = torch.sigmoid(self.conv3(img))#注意力部分
x1 = x1*x2
x1 = x1+x
return x1, img
#self.sam12 = SAM(n_feat, kernel_size=1, bias=bias)#MPRNET中的监督注意力模块初始化
#self.sam23 = SAM(n_feat, kernel_size=1, bias=bias)
MPRNET(与跨阶段提取特征):![](https://img-blog.csdnimg.cn/317b66e901cf491592bccb3200ca21ba.png)
csff的优点,它使网络在编码器-解码器中由于重复使用上、下采样操作而减少信息损失。第二,一个阶段的多尺度特征有助于丰富下一个阶段的特征。第三,网络优化过程变得更加稳定,因为它缓解了信息的流动,从而允许我们在整体架构中添加几个阶段。(我认为也是跨阶段特征融合的优点)
MPRNET在使用了跨阶段特征融合、SAM等技术使得不同阶段可以提取不同的特征并且有效的传递到下一个阶段并进行融合。又使用ORSNET去融合不同通道的特征(不进行下采样)保证了空间细节又有了比较大的感受野。
MPRNET代码:
class MPRNet(nn.Module):
def __init__(self, in_c=3, out_c=3, n_feat=40, scale_unetfeats=20, scale_orsnetfeats=16, num_cab=8, kernel_size=3, reduction=4, bias=False):
super(MPRNet, self).__init__()
act=nn.PReLU()
self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act))
self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act))
self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act))
# Cross Stage Feature Fusion (CSFF)
self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False)
self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab)
self.sam12 = SAM(n_feat, kernel_size=1, bias=bias)
self.sam23 = SAM(n_feat, kernel_size=1, bias=bias)
self.concat12 = conv(n_feat*2, n_feat, kernel_size, bias=bias)
self.concat23 = conv(n_feat*2, n_feat+scale_orsnetfeats, kernel_size, bias=bias)
self.tail = conv(n_feat+scale_orsnetfeats, out_c, kernel_size, bias=bias)
def forward(self, x3_img):
# Original-resolution Image for Stage 3
H = x3_img.size(2)
W = x3_img.size(3)
# Multi-Patch Hierarchy: Split Image into four non-overlapping patches
# Two Patches for Stage 2
x2top_img = x3_img[:,:,0:int(H/2),:]
x2bot_img = x3_img[:,:,int(H/2):H,:]
# Four Patches for Stage 1
x1ltop_img = x2top_img[:,:,:,0:int(W/2)]
x1rtop_img = x2top_img[:,:,:,int(W/2):W]
x1lbot_img = x2bot_img[:,:,:,0:int(W/2)]
x1rbot_img = x2bot_img[:,:,:,int(W/2):W]
##-------------------------------------------
##-------------- Stage 1---------------------
##-------------------------------------------
## Compute Shallow Features
x1ltop = self.shallow_feat1(x1ltop_img)
x1rtop = self.shallow_feat1(x1rtop_img)
x1lbot = self.shallow_feat1(x1lbot_img)
x1rbot = self.shallow_feat1(x1rbot_img)
## Process features of all 4 patches with Encoder of Stage 1
feat1_ltop = self.stage1_encoder(x1ltop)
feat1_rtop = self.stage1_encoder(x1rtop)
feat1_lbot = self.stage1_encoder(x1lbot)
feat1_rbot = self.stage1_encoder(x1rbot)
## Concat deep features
feat1_top = [torch.cat((k,v), 3) for k,v in zip(feat1_ltop,feat1_rtop)]
feat1_bot = [torch.cat((k,v), 3) for k,v in zip(feat1_lbot,feat1_rbot)]
## Pass features through Decoder of Stage 1
res1_top = self.stage1_decoder(feat1_top)
res1_bot = self.stage1_decoder(feat1_bot)
## Apply Supervised Attention Module (SAM)
x2top_samfeats, stage1_img_top = self.sam12(res1_top[0], x2top_img)
x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img)
## Output image at Stage 1
stage1_img = torch.cat([stage1_img_top, stage1_img_bot],2)
##-------------------------------------------
##-------------- Stage 2---------------------
##-------------------------------------------
## Compute Shallow Features
x2top = self.shallow_feat2(x2top_img)
x2bot = self.shallow_feat2(x2bot_img)
## Concatenate SAM features of Stage 1 with shallow features of Stage 2
x2top_cat = self.concat12(torch.cat([x2top, x2top_samfeats], 1))
x2bot_cat = self.concat12(torch.cat([x2bot, x2bot_samfeats], 1))
## Process features of both patches with Encoder of Stage 2
feat2_top = self.stage2_encoder(x2top_cat, feat1_top, res1_top)
feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot)
## Concat deep features
feat2 = [torch.cat((k,v), 2) for k,v in zip(feat2_top,feat2_bot)]
## Pass features through Decoder of Stage 2
res2 = self.stage2_decoder(feat2)
## Apply SAM
x3_samfeats, stage2_img = self.sam23(res2[0], x3_img)
##-------------------------------------------
##-------------- Stage 3---------------------
##-------------------------------------------
## Compute Shallow Features
x3 = self.shallow_feat3(x3_img)
## Concatenate SAM features of Stage 2 with shallow features of Stage 3
x3_cat = self.concat23(torch.cat([x3, x3_samfeats], 1))
x3_cat = self.stage3_orsnet(x3_cat, feat2, res2)
stage3_img = self.tail(x3_cat)
return [stage3_img+x3_img, stage2_img, stage1_img]
评价指标的理解:
(下面一部分是对SSIM与PSNR的理解,SSIM与PSNR可以用来计算两张图片的相似度,但是这些基于几何的相似度方法在(去雨、去雨、去噪等图像较为平滑)图像修复问题上不如LPIPS)
PSNR:
PSNR值越大,表示图像的质量越好,一般来说:
(1)高于40dB:说明图像质量极好(即非常接近原始图像)
(2)30—40dB:通常表示图像质量是好的(即失真可以察觉但可以接受)
(3)20—30dB:说明图像质量差
(4)低于20dB:图像质量不可接受
SSIM:
传统基于MSE的损失不足以表达人的视觉系统对图片的直观感受。例如有时候两张图片只是亮度不同,但是之间的MSEloss相差很大。而一幅很模湖与另一幅很清的图,它们的MSE loss可能反而相差很小。
SSIM使用灰度的平均值来衡量亮度、使用灰度的标准差来衡量对比度、使用皮尔逊相关性来衡量结构。
C1、C2与C3是常数项用来防止除0。
每次计算的时候都从图片上取一个N×N的窗口,然后不断滑动窗口进行计算,最后取平均值作为全局的SSIM。
LPIPS:
PNSR与SSIM忽略了边缘信息,导致较为平滑的图像得到的分数更高。
《The unreasonable effectiveness of deep features as a perceptual metric.》论文简介
虽然人类几乎不费力地快速评估两幅图像之间的感知相似性,但其潜在的过程被认为是相当复杂的。尽管如此,如今使用最广泛的感知指标,如PSNR和SSIM,都是简单的、浅层的函数,并没有考虑到人类感知的许多细微差别。最近,深度学习社区发现,在ImageNet分类上训练的VGG网络的特征作为图像合成的训练损失非常有用。但这些所谓的"知觉损失"是怎样的?哪些因素对他们的成功至关重要?为了回答这些问题,我们引入了一个新的人类知觉相似性判断数据集。我们系统地评估了跨越不同架构和tas的深层特征。
用深度的网络来比较图片可以得到更好的结果。
LPIPS计算成本相对较高,可能会导致训练过程更慢,因此在实际应用中需要权衡计算资源和性能需求。
损失函数
在任一给定阶段S,所提出的模型不是直接预测复原图像,而是预测残差图像
,将退化的输入图像I加入残差图像
,得到:
。我们使用如下损失函数对MPRNet进行端到端的优化:
式中Y表示标签图像(没有受损的图像),为Charbonnier损失函数[ 12 ]:
固定ε,所有实验经验设定为。此外,
是边缘损失,定义为:
式中:∆表示Laplacian算子
可以看到损失函数是每个阶段的预测图像的与λ(0.05)乘于
之和。(去噪任务中只有
)
因为在图像修复任务中是需要将图像锐化,所以损失函数需要注意于边缘信息。所以设计了
明明论文中指出使用Laplacian算子用于计算边缘信息,但是在代码中使用拉普拉斯金字塔的方式进行计算边缘信息,虽然两者都可以计算但是效果一样吗?
下面是代码中定义的损失函数:
class CharbonnierLoss(nn.Module):
"""Charbonnier Loss (L1)"""
def __init__(self, eps=1e-3):
super(CharbonnierLoss, self).__init__()
self.eps = eps
def forward(self, x, y):
diff = x - y
# loss = torch.sum(torch.sqrt(diff * diff + self.eps))
loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
return loss
class EdgeLoss(nn.Module):
def __init__(self):
super(EdgeLoss, self).__init__()
k = torch.Tensor([[.05, .25, .4, .25, .05]])
self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1)#最后形状为(3,1,5,5)
if torch.cuda.is_available():
self.kernel = self.kernel.cuda()
self.loss = CharbonnierLoss()
def conv_gauss(self, img):#高斯模糊
n_channels, _, kw, kh = self.kernel.shape
img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
return F.conv2d(img, self.kernel, groups=n_channels)
def laplacian_kernel(self, current):#利用高斯模糊求出拉普拉斯金字塔(图片的高频信息)
filtered = self.conv_gauss(current) # filter
down = filtered[:,:,::2,::2] # downsample
new_filter = torch.zeros_like(filtered)
new_filter[:,:,::2,::2] = down*4 # upsample
filtered = self.conv_gauss(new_filter) # filter
diff = current - filtered
return diff
def forward(self, x, y):
loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
return loss
#train.py中的损失函数计算
# loss_char = np.sum([criterion_char(restored[j],target) for j in range(len(restored))])
# loss_edge = np.sum([criterion_edge(restored[j],target) for j in range(len(restored))])
# loss = (loss_char) + (0.05*loss_edge)
下面的部分为我翻译的论文(含注解)
Multi-Stage Progressive Image Restoration(多阶段渐进图像修复)
简介
图像修复任务在修复图像时需要空间细节和高级上下文信息之间的复杂平衡。在本文中,我们提出了一种新颖的协同设计,可以最佳地平衡这些相互竞争的目标。我们的主要建议是多阶段架构,它逐步学习降级输入的修复函数,从而将整个修复过程分解为更易于管理的步骤。具体来说,我们的模型首先使用编码器-解码器架构学习上下文特征,然后将它们与保留局部信息的高分辨率分支相结合。在每个阶段,我们都引入了一种新颖的每像素自适应设计(per-pixel adaptive design),利用原位监督注意力来重新加权局部特征。这种多阶段架构的一个关键要素是不同阶段之间的信息交换(多级特征融合、特征金字塔)。为此,我们提出了一种双方面的方法,其中信息不仅从早期到后期顺序交换,而且特征处理块之间也存在横向连接以避免任何信息丢失。由此产生的紧密互连的多级架构(称为 MPRNet)在图像去雨、去模糊和去噪等一系列任务中的 10 个数据集上提供了强大的性能提升。源代码和预训练模型可在 GitHub - swz30/MPRNet: [CVPR 2021] Multi-Stage Progressive Image Restoration. SOTA results for Image deblurring, deraining, and denoising. 获取。
介绍
图像复原是指从受损的图像中修复出一幅干净图像的任务。受损的典型例子包括噪声、模糊、雨、雾霾等。这是一个高度不适定的问题,因为存在无穷多个可行解。为了将解空间限制在有效/自然图像上,现有的复原技术[ 19,29,39,59,66,67,100]显式地使用了由经验观测手工制作的图像先验。然而,设计这样的先验是一项具有挑战性的任务,而且往往不能泛化。为了改善这个问题,最近的最先进的方法[ 17、44、57、86、87、93、94、97]使用卷积神经网络( CNNs ),通过从大规模数据中捕获自然图像统计来隐式地学习更一般的先验。
基于CNN的方法相对于其他方法的性能增益主要归因于其模型设计。已经开发了许多用于图像复原的网络模块和功能单元,包括递归残差学习[ 4,95]、空洞卷积[ 4,81]、注意力机制[ 17,86,96]、密集连接[ 73,75,97]、编码器-解码器[ 7、13、43、65]和生成模型[ 44,62,90,92]。然而,几乎所有这些针对低层视觉问题的模型都是基于单阶段设计的。相比之下,多级网络在诸如姿势估计器[ 14、46、54]、场景解析[ 15 ]和动作分割[ 20,26,45]等高级视觉问题中表现出比单级网络更有效的性能。
最近,很少有工作将多级网络设计引入到图像去模糊[ 70、71、88 ]和图像去雨[ 47、63 ]中。我们对这些方法进行分析,以确定妨碍其性能的架构瓶颈。首先,现有的多阶段技术要么使用编码器-解码器架构[ 71、88 ],该架构有效地编码了广泛的上下文信息,但在保留了空间图像细节方面不可靠;要么使用单尺度管道[ 63 ],该结构提供了空间精确但是在语义上缺少可靠的输出。然而,我们表明,有效的图像复原需要在多级架构中结合两种设计选择。其次,我们证明单纯地将一个阶段的输出传递到下一个阶段会产生次优的结果[ 53 ]。第三,与[ 88 ]中不同的是,在渐进修复的每个阶段提供真实的监督是很重要的。最后,在多阶段处理过程中,需要一种从早期到后期传播中间特征的机制来保存来自编码器-解码器分支的上下文特征。
我们提出了一种多级递进的图像复原结构,称为MPRNet,它包含几个关键组件。1 )。前几个阶段使用编码器-解码器来学习多尺度上下文信息,而最后一个阶段对原始图像分辨率进行操作,以保留精细的空间细节。2 ) .每两个阶段之间插入一个监督注意力模块( SAM ),以实现渐进式学习。在真实图像的指导下,该模块利用上一阶段的预测来计算注意力图,注意力图反过来用于在传递到下一阶段之前优化上一阶段的特征。3 ) .增加了跨阶段特征融合( CSFF )机制,有助于将多尺度上下文特征从前阶段传播到后阶段。此外,该方法缓解了阶段间的信息流动,有效地稳定了多阶段网络优化。
这项工作的主要贡献是:
•一种新颖的多阶段方法,能够产生上下文丰富和空间准确的输出。由于其多阶段的性质,我们的框架将具有 挑战性的图像复原任务分解为子任务,以逐步修复退化图像;
•一个有效的监督注意力模块,充分利用每个阶段修复的图像在进一步传播之前对输入特征进行细化;
•一个跨阶段聚合多尺度特征的策略;
•我们通过在10个合成和真实数据集上设置新的最先进的MPRNet来证明我们的MPRNet的有效性,包括图像 去雨、去模糊和去噪,同时保持低复杂度(见图1)。进一步,我们提供了详细的消融、定性结果和泛化测试。
相关工作
近年来出现了从高端数码单反相机到智能手机相机的范式转变。然而,使用智能手机相机拍摄高质量图像具有挑战性。由于相机的限制和/或不利的环境条件,图像中图像往往存在受损。早期的复原方法有基于全变分的[ 10,67]、基于稀疏编码的[ 3、51、52]、基于自相似性的[ 8、16]、基于梯度先验的[ 68、80 ]等。最近,基于CNN的复原方法取得了[ 57、70、86、93、97]的最新结果。在模型结构设计方面,这些方法大致可以分为单阶段和多阶段。
单阶段方法。目前,大多数图像复原方法都是基于单阶段设计的,而模型模块通常是基于高层视觉任务开发的。例如,残差学习[ 30 ]已经用于图像去噪[ 2,72,93]、图像去模糊[ 42、43]和图像去雨[ 37 ]。同样,为了提取多尺度信息,[ 4,28,43]常用编码器-解码器[ 65 ]和空洞卷积[ 83 ]模型。其他的单阶段方法[ 5,89,97]引入了密集连接[ 34 ]。
多阶段方法。这些方法[ 24、47、53、63、70、71、88、99]旨在通过在每个阶段使用轻量级子网络以渐进的方式修复干净的图像。这样的设计是有效的,因为它将具有挑战性的图像复原任务分解为更小的更容易的子任务。然而,通常的做法是在每个阶段使用相同的子网络,这可能会产生次优的结果,如我们的实验(第4节)所示。
注意力。在其在图像分类[ 31、32、79]、分割[ 21、35]和检测[ 74,79]等高层任务中取得成功的推动下,注意力模块被用于低层视觉任务[ 38 ]。包括图像去雨方法[ 37、47]、去模糊方法[ 61、70 ]、超分辨率方法[ 17,95]、去噪方法[ 4,86]等。其主要思想是捕捉空间维度[ 98 ]、通道维度[ 32 ]或两者[ 79 ]之间的长程相关性。
多阶段渐进修复
所提出的图像修复框架,如图2所示,由三个阶段来逐步修复图像。前两个阶段基于编码器-解码器子网络,由于较大的感受野而学习广泛的上下文信息。由于图像修复是一个位置敏感的任务(这就需要从输入到输出进行像素到像素的对应),最后一个阶段使用一个子网络对原始输入图像分辨率(不进行任何下采样操作)进行操作,从而在最终的输出图像中保留所需的精细纹理。
我们不是简单地级联多个阶段,而是在每两个阶段之间合并一个有监督的注意力模块。在真值图像的监督下,我们的模块在将上一阶段的特征图传递到下一阶段之前对其进行重新缩放。此外,我们引入了跨阶段特征融合机制,前一个子网络的中间多尺度上下文特征有助于巩固后一个子网络的中间特征。
尽管MPRNet叠加了多个阶段,但每个阶段都有对输入图像的访问权限。与最近的修复方法[ 70、88 ]类似,我们在输入图像上采用多块分层结构,将图像分割成不重叠的块:第一阶段4块,第二阶段2块,最后一阶段的原始图像,如图2所示。
在任一给定阶段S,所提出的模型不是直接预测复原图像,而是预测残差图像
,将退化的输入图像I加入残差图像
,得到:
。我们使用如下损失函数对MPRNet进行端到端的优化:
式中Y表示标签图像(没有受损的图像),为Charbonnier损失函数[ 12 ]:
固定ε,所有实验经验设定为。此外,
是边缘损失,定义为:
式中:∆表示Laplacian算子(在代码中是使用原图减去高斯模糊然后下采样之后再上采样再高斯模糊的结果,两者都是求高频信息但两者一样吗?)。方程中的参数λ。( 1 )控制了两个损失项的相对重要性,如文献[ 37 ]中设定的0.05。接下来,我们描述了我们方法的每个关键元素。(通过损失函数中加入高频信息的考虑,可以帮助网络更好地保持图像的结构和纹理,从而生成更加真实和清晰的修复图像。)
互补特征处理
现有的用于图像复原的单级CNN通常使用以下架构设计之一:1 ) .编码器-解码器,或2 ) .单尺度特征管道。编码器-解码器网络[ 7、13、43、65]首先逐步将输入映射到低分辨率表示,然后逐步应用反向映射恢复原始分辨率。虽然这些模型有效地编码了多尺度信息,但由于重复使用下采样操作,它们容易牺牲空间细节。相比之下,在单尺度特征管道上运行的方法在生成具有精细空间细节[ 6、18、93、97]的图像时是可靠的。然而,由于有限的感受野,它们的输出在语义上是不稳定的。这表明了上述架构设计选择的固有局限性,即能够生成空间精确或上下文可靠的输出,但不能同时生成这两种输出。为了利用这两种设计的优点,我们提出了一个多级框架,其中早期阶段包含编码器-解码器网络,最后阶段使用一个在原始输入分辨率上运行的网络。
编码器-解码器子网络。图3a展示了我们提出的基于标准U - Net [ 65 ]的编码器-解码器子网络。首先,我们在每个尺度( CABs见图3(b))上添加通道注意力模块( Channel Attention Blocks,CABs ) [ 95 ]来提取特征。其次,U - Net跳跃连接处的特征图也用CAB处理。最后,在解码器中使用双线性插值后再使用卷积层(相当于放大之后再卷积),而不是使用转置卷积来提高特征的空间分辨率。这有助于减少由于转置卷积而导致的输出图像中的棋盘格伪影[ 55 ]。
原始分辨率子网络。为了保留从输入图像到输出图像的细节信息,我们在最后一个阶段(见图2)中引入了原始分辨率子网络( ORSNet )。ORSNet不使用任何下采样操作,生成空间丰富的高分辨率特征。它由多个原始分辨率块( ORB )组成,每个ORB进一步包含CAB。ORB示意图如图3b所示。
跨阶段特征融合
在我们的框架中,我们在两个编码器-解码器(见图3c)之间以及编码器-解码器与ORSNet (见图3d)之间引入了CSFF模块。注意,来自一个阶段的特征在传播到下一个阶段进行聚合之前先用1 × 1卷积提取特征。提出的CSFF具有几个优点。首先,它使网络在编码器-解码器中由于重复使用上、下采样操作而减少信息损失。第二,一个阶段的多尺度特征有助于丰富下一个阶段的特征。第三,网络优化过程变得更加稳定,因为它缓解了信息的流动,从而允许我们在整体架构中添加几个阶段。
监督注意模块
最近的用于图像恢复的多级网络[ 70、88 ]直接在每个阶段预测一幅图像,然后将其传递到下一个连续的阶段。相反,我们在每两个阶段之间引入一个有监督的注意力模块,这有助于实现显著的性能增益。SAM的示意图如图4所示,其贡献是双重的。首先,它为每个阶段的渐进图像复原提供了有用的监督信号。其次,在局部监督预测的帮助下,我们生成注意力图来抑制当前阶段信息较少的特征,只允许有用的特征传播到下一阶段。
如图4所示,SAM利用前一阶段的输入特征,首先通过简单的1 × 1卷积生成残差图像
,其中H × W表示空间维度,C表示通道数。将残差图像添加到受损的输入图像I中,得到修复图像。对于这个预测图像
,我们提供了真值图像的显式监督。然后,使用1 × 1卷积和sigmoid激活从图像
中生成超像素注意力掩码
。然后利用这些掩码对变换后的局部特征
(经过1 × 1卷积后得到)进行重新标定,得到注意力引导的特征,并将其添加到身份映射路径中。最后,将SAM产生的注意力增强特征表示传递到下一阶段进行进一步处理。
实验和分析
我们在10个不同的数据集上评估了我们的方法,包括( a )图像去雨,( b )图像去模糊,( c )图像去噪。
数据集和评估协议
使用PSNR和SSIM [ 76 ]指标进行定量比较。与文献[ 7 ]一样,我们通过将PSNR转换为RMSE
和SSIM转换为DSSIM ( DSSIM = ( 1-SSIM )/ 2) 来报告(括号内)每个方法相对于最佳方法的误差减少。用于训练和测试的数据集总结在表1中,下面进行描述。
图像去雨。使用与最新最好的图像去雨方法[ 37 ]相同的实验设置,我们在来自多个数据集[ 23、48、81、89、90]的13712张干净雨图像对上训练了我们的模型,如表1所示。利用这个单一训练好的模型,我们对各种测试集进行了评估,包括Rain100H [ 81 ]、Rain100L [ 81 ]、Test100 [ 90 ]、Test2800 [ 23 ]和Test1200 [ 89 ]。
图像去模糊。与[ 70,88,43,71]一样,我们使用GoPro [ 53 ]数据集,其中包含2,103(指的是有103对,一张模糊的图像,一张对应的清晰的图像组成一对)对图像对用于训练,1,111(111张模糊的图像用于评估)对用于评估。此外,为了证明模型的可推广性,我们将Go Pro训练好的模型直接应用于HIDE [ 69 ]和RealBlur [ 64 ]数据集的测试图像。HIDE数据集专门用于人体感知运动去模糊,其测试集包含2 025张图像。在综合生成GoPro和HIDE数据集的同时,在真实世界条件下捕获了RealBlur数据集的图像对。RealBlur数据集有两个子集:( 1 )通过相机JPEG输出形成RealBlur - J;( 2 )对RAW图像进行白平衡、去马赛克和去噪操作,离线生成RealBlur - R。
图像去噪。为了训练我们的模型用于图像去噪任务,我们使用了SIDD数据集的320张高分辨率图像[ 1 ]。在来自SIDD数据集[ 1 ]的1,280个验证补丁和来自DND基准数据集[ 60 ]的1,000个补丁上进行评估。这些测试块由原始作者从全分辨率图像中提取。SIDD和DND数据集均由真实图像组成。
实现细节
我们的MPRNet是端到端可训练的,不需要预训练。我们针对三种不同的任务分别训练模型。我们在编码器-解码器的每个尺度上使用2个CAB,对于下采样我们使用2 × 2的最大池化步幅2。在最后一个阶段,我们使用包含3个ORB的ORSNet,每个ORB进一步使用8个CAB。根据任务的复杂程度,我们将通道数设置为40个用于去雨,80个用于去噪,96个用于去模糊。网络在批大小为16的256 × 256块上训练,迭代次数为4 × 105。对于数据增强,随机施加水平和垂直翻转。我们使用Adam优化器[ 41 ],初始学习速率为2 × 10-4,使用余弦退火策略[ 50 ]将初始学习速率稳定地降低到1 × 10 - 6。
图像去雨结果
对于图像去雨任务,与之前的工作一致[ 37 ],我们使用Y通道(在YCbCr颜色空间中)计算图像质量分数。表2显示,我们的方法在所有5个数据集上都取得了更好的PSNR / SSIM评分,显著地提升了当前最好的方法。与最近最好的算法MSPFN [ 37 ]相比,我们获得了1.98 dB (所有数据集的平均值)的性能增益,表明误差减少了20 %。在部分数据集上的提升幅度高达4 d B,如Rain100L [ 81 ]。此外,我们的模型比MSPFN [ 37 ]少了3.7 倍个参数,而速度提高了2.4 倍。
图5显示了对挑战性图像的视觉比较。我们的MPRNet可以有效地去除不同方向和大小的雨线,并生成视觉上令人愉快和更像于标签的图像。相比之下,其他方法折中了结构内容(第一行),引入了伪影(第二行),没有完全去除雨线(第三行)。
图像去模糊结果
我们在表3中报告了在合成的GoPro [ 53 ]和HIDE [ 69 ]数据集上评估图像去模糊方法的性能。总的来说,我们的模型与其他算法相比表现良好。与之前最好的方法[ 70 ]相比,我们的方法在GoPro [ 53 ]数据集上的PSNR和SSIM分别提高了9 %和21 %,在HIDE数据集上的误差分别降低了11 %和13 % [ 69 ]。值得注意的是,我们的网络仅在GoPro数据集上训练,但在HIDE数据集上取得了最先进的结果( + 0.98 d B ),从而展示了其强大的泛化能力。
我们在最近的一个RealBlur [ 64 ]数据集的真实世界图像上评估了我们的MPRNet在两个实验设置下的表现:1 ) .直接在RealBlur (测试对真实图像的泛化能力)上应用GoPro训练的模型,2 ) .在RealBlur数据上进行训练和测试。表4为实验结果。对于设置1,我们的MPRNet比DMPHN算法在RealBlur - R子集上获得了0.29 d B的性能增益,在RealBlur - J子集上获得了0.28 d B的性能增益[ 88 ]。对于设置2,我们观察到类似的趋势,在RealBlur - R和RealBlur - J上,我们在SRN [ 71 ]上的增益分别为0.66 dB和0.38 dB。
图6显示了一些经过评估的去模糊图像。总体而言,我们的模型恢复的图像比其他模型恢复的图像更清晰,更接近真实图像。
图像去噪结果
在表5中,我们报告了几种图像去噪方法在SIDD [ 1 ]和DND [ 60 ]数据集上的PSNR / SSIM得分。我们的方法获得了相当大的增益,在SIDD上比CycleISP [ 86 ]高0.19 dB,在DND上比SADNet [ 11 ]高0.21 dB。注意到DND数据集不包含任何训练图像,即公开发布的完整数据集只是一个测试集。在DND基准测试集上使用我们的SIDD训练模型的实验结果表明,我们的模型对不同的图像域都有很好的泛化能力。
图7为可视化结果。我们的方法能够去除真实噪声,同时保留图像的结构和纹理细节。相比之下,其他方法恢复的图像要么包含过于平滑的内容,要么包含纹理粗糙的伪影。
在这里,我们给出了消融实验来分析我们模型的每个组成部分的贡献。在GoPro数据集上进行评估[ 53 ],使用在大小为128 × 128的图像块上训练的去模糊模型进行105次迭代,结果如表6所示。
阶段数。我们的模型随着阶段数的增加表现出更好的性能,验证了我们多阶段设计的有效性。
子网络的选择。由于我们的模型每个阶段可以采用不同的子网络设计,因此我们测试了不同的选项。我们表明,在前阶段使用编码器-解码器,在最后阶段使用ORSNet,与对所有阶段( U-Net + UNet为29.4 d B , ORSNet + ORSNet为29.53 d B)采用相同的设计相比,可以提高性能( 29.7 dB )。
SAM和CSFF .我们将提出的监督注意力模块和跨阶段特征融合机制从最终的模型中移除,证明了其有效性。从表6可以看出,去除SAM后,PSNR从30.49 d B下降到30.07 d B,去除CSFF后,PSNR从30.49 d B下降到30.31 d B。去除这两个成分会使性能从30.49 dB大幅下降到29.86 dB。
高效的图像修原
CNN模型通常表现出在精度和计算效率之间的权衡。为了追求更高的精度,往往开发更深更复杂的模型。尽管大模型往往比小模型表现更好,但计算成本可能高得令人望而却步。因此,开发资源高效的图像恢复模型是非常有意义的。一种解决方法是每次改变目标系统时,通过调整同一网络的容量来训练该网络。然而,它是繁琐的,而且往往是不可行的。更可取的方法是有一个单一的网络,可以( a )对计算有效的系统进行早期预测,( b )进行后期预测,以获得较高的准确性。多阶段修复模型自然地提供了这样的功能。
表7报告了我们的多阶段方法的阶段结果。我们的MPRNet在每个阶段都展示了具有竞争力的恢复性能。值得注意的是,我们的stage - 1模型轻量、快速,并且比其他复杂的算法如SRN [ 71 ]和DeblurGANv2 [ 43 ]产生更好的结果。同样,与最近的方法DMPHN [ 88 ]相比,我们的stage - 2模型的PSNR增益为0.51 dB,同时具有更高的资源效率( 少于2 倍参数和13 倍的速度)。
在这项工作中,我们提出了一种用于图像修复的多阶段架构,通过在每个阶段加入入监督来逐步改善受损的输入。我们为我们的设计制定了指导原则,要求在多个阶段进行互补的特征处理,并在这些阶段之间进行灵活的信息交换。为此,我们提出了上下文丰富且空间准确的阶段,以统一的方式编码多样化的特征集合。为了保证交互阶段之间的协同性,我们提出了跨阶段的特征融合和注意力引导的输出交换。我们的模型在大量的基准数据集上取得了显著的性能提升。此外,我们的模型在模型大小方面是轻量级的,在运行时间方面是高效的,这对于资源有限的设备来说是非常有意义的。