Multi-Stage Progressive Image Restoration
写作思路
- 之前的人做了什么
- 用了什么方法,做了什么研究
- 这些方法存在什么问题
- 针对问题,本文采用什么方法做研究,使用什么方法(数据集)来验证效果
- 提出每个模块:以什么为基础,提出了什么结构
摘要
设计了一种架构:balance spatial details and high-level contextualized information
- 多阶段架构 (使用encdec架构+高分辨率分支)
- 监督 attention 模块
- 跨阶段间特征融合模块(顺序+横向连接)
Introduction
图像恢复任务简介
-
Image restoration is the task of recovering a clean image from its degraded version.
-
Typical examples of degradation include noise, blur, rain, haze, etc.
-
It is a highly ill-posed problem as there exist infinite feasible solutions.
之前图像恢复的方法
-
手工设计特征 - 基于先验的方法难以推广
-
深度学习方法 - 大多数方法都是单阶段的 (但是多阶段更有效)
-
分析现有的多阶段的方法的架构存在问题
- encoderdeocoder 提取上下文信息有效但保存空间细节差 - 上下采样来提高感受野(卷积的区域扩大,更好的提取上下文特征),但不断的上下采样使其空间细节不断降低
- encoderdeocoder 提取上下文信息有效但保存空间细节差 - 上下采样来提高感受野(卷积的区域扩大,更好的提取上下文特征),但不断的上下采样使其空间细节不断降低
-
single-scale pipeline (没有过多的使用上下采样)保留空间细节 输出不可靠(不能很好的提取上下文信息)
-
分析结果
- 我们需要结合以上两种架构的优点
- naively passing the output of one stage to the next stage yields suboptimal results (不是最好的结果).
- it is important to provide ground-truth supervision at each stage for progressive restoration(使每个阶段专注于该阶段不同尺度/分辨率的噪音程度恢复-细化每个阶段的恢复过程).
- 需要有一种将多阶段中间特征融合传递的机制,保留encdec分支的上下文特征
-
基于以上现状,做了以下工作
- 提出多阶段架构,早期阶段encoderdecoder(丰富上下文信息), 晚阶段使用原始分辨率图像融合的模块(补充空间细节)
- 提出SAM模块 利用前一阶段的预测结果来生成attention,监督学习专注该阶段的恢复,利用当前阶段的预测去细化特征转递给下一阶段
- 跨阶段的特征融合方法-将不同阶段的多尺度特征更好的融合,且简化阶段间的信息流
- 通过在 10 个合成和真实数据集上设置新的最先进技术,用于各种恢复任务
相关工作
数据增强()
- 从Beta分布中生成随机的混合系数 λ(lam)作为混合的权重。
- 随机排列输入数据集,得到另一组具有相同数量样本的索引。
- 使用λ将原始图像和相应的随机排列图像按照一定权重混合,从而生成新的ground truth图像。
- 同样,使用相同的λ将原始噪声图像和相应的随机排列噪声图像按照一定权重混合,从而生成新的噪声图像
def aug(self, rgb_gt, rgb_noisy):
bs = rgb_gt.size(0)
indices = torch.randperm(bs)
rgb_gt2 = rgb_gt[indices]
rgb_noisy2 = rgb_noisy[indices]
lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda()
rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2
rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2
return rgb_gt, rgb_noisy
损失函数
- 这里是通过求拉普拉斯金字塔来等价拉普拉斯算子的,拉普拉斯算子(二阶微分)应用于图像提取图像的边缘和纹理等高频特征。与拉普拉斯金字塔得出的类似
multi-stage的架构 采用分辨率由细到粗
- 在从粗到细的方案下,由于滤波器尺寸较大,大多数网络使用大量训练参数。因此,多尺度和尺度循环方法会导致运行时间昂贵(见,并且很难提高去模糊质量。
- 简单地通过更精细的尺度级别增加模型深度并不能提高去模糊的质量。
- 改为从精到粗,每个阶段预测残差,将图像分patch, 对每个patch从细到粗的处理最后再合并
Original Resolution Subnetwork
-
卷积提取特征+CAB通道注意力(为了保留空间细节不使用空间注意力)
-
使用原始分辨率的图像丰富空间细节 不断使用原始分辨率的图像跳越连接 保留高分辨率的细节
Cross-stage Feature Fusion
-
提出一种低级上下文信息与高级上下文信息融合的范式,值得思考以往的网络架构中如何融合低级与高级的语义信息。
-
操作:Note that the features from one stage are first refined with 1 × 1 convolutions before propagating them to the next stage for aggregation.
-
意义
-
encdec每层的特征也传递,减少因为频繁上下采样造成信息丢失
-
前一个阶段的多尺度特征会丰富到下一个阶段的特征
-
网络优化过程变得更加稳定,因为它简化了信息流,从而允许我们在整体架构中添加几个阶段
class ORSNet(nn.Module): def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab): super(ORSNet, self).__init__() self.orb1 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) self.orb2 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) self.orb3 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) self.up_enc1 = UpSample(n_feat, scale_unetfeats) self.up_dec1 = UpSample(n_feat, scale_unetfeats) self.up_enc2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats)) self.up_dec2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats)) #融合三个尺度的特征 self.conv_enc1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) self.conv_enc2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) self.conv_enc3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) self.conv_dec1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) self.conv_dec2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) self.conv_dec3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) def forward(self, x, encoder_outs, decoder_outs): x = self.orb1(x) x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0]) x = self.orb2(x) x = x + self.conv_enc2(self.up_enc1(encoder_outs[1])) + self.conv_dec2(self.up_dec1(decoder_outs[1])) x = self.orb3(x) x = x + self.conv_enc3(self.up_enc2(encoder_outs[2])) + self.conv_dec3(self.up_dec2(decoder_outs[2])) return x
Supervised Attention Module
- 监测每个阶段图像恢复的效果+先验
- 通过生成注意力图,使得下一个阶段更关注有用的信息(放大该阶段修复的特征)
- Fin ∈ RH×W ×C of the earlier stage and first generates a residual image RS ∈ RH×W ×3(反卷积的特征) with a simple 1 × 1 convolution通过1*1降维
- 残差+受损图像生成恢复图像用于监督损失
- Next, perpixel attention masks M ∈ RH×W ×C are generated from the image XS using a 1×1 convolution followed by the sigmoid activation
class SAM(nn.Module):
def __init__(self, n_feat, kernel_size, bias):
super(SAM, self).__init__()
self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)
self.conv3 = conv(3, n_feat, kernel_size, bias=bias)
def forward(self, x, x_img):
x1 = self.conv1(x)
img = self.conv2(x) + x_img #X_S恢复图像
x2 = torch.sigmoid(self.conv3(img)) #生成注意图
x1 = x1*x2
x1 = x1+x
return x1, img
实验
数据集
metric PSNR and SSIM metrics
- Deraining
- Deblurring
- GoPro trained model and directly apply it on the test images of the HIDE and RealBlur datasets.
- While the GoPro and HIDE datasets are synthetically generated
- the image pairs of RealBlur dataset are captured in real-world conditions
- Denoising
- Evaluation is conducted on 1,280 validation patches from the SIDD dataset [1] and 1,000 patches from the DND benchmark dataset
- patches are extracted from the full resolution images
实现配置
- 端到端 、不需要预训练、为每个任务训练单独的模型
- 网络配置
- 2 CABs at each scale of the encoder-decoder
- downsampling we use 2×2 max-pooling with stride 2
- In the last stage, we employ ORSNet that contains 3 ORBs, each of which further uses 8 CABs
- number of channels to 40 for deraining, 80 for denoising, and 96 for deblurring
- 训练配置
- trained on 256×256 patches with a batch size of 16 for 4×105 iterations.
- 数据增强
- horizontal and vertical flips are randomly applied
- 优化器
- Adam optimizer with the initial learning rate of 2×10−4
实验结果
- 评价指标
- PSNR(Peak Signal-to-Noise Ratio):
- PSNR是一种用于测量图像失真的指标,通常用于比较原始图像和经过处理(压缩、降噪等)的图像之间的质量。它以分贝(dB)为单位表示,计算方式如下: PSNR = 10 * log10((MAX^2) / MSE) 其中,MAX 是图像像素值的最大可能值(通常为255对于8位图像),而MSE(Mean Squared Error)是原始图像与处理后图像之间的均方误差。PSNR值越高,表示图像质量越好,失真越小。
- PSNR对于噪声较小、适用于线性变换的图像处理任务非常有用,但它不总是能够捕捉到人眼对图像质量的感知。因此,在某些情况下,它可能与人类主观感知不一致。
- SSIM(Structural Similarity Index):
- SSIM是一种更复杂的图像质量度量指标,它考虑了亮度、对比度和结构三个方面的信息,从而更全面地评估了图像之间的相似性。SSIM的计算方式包括以下三个组成部分:
- 亮度相似性(Luminance Similarity):评估了亮度分布的相似性。
- 对比度相似性(Contrast Similarity):评估了图像对比度的相似性。
- 结构相似性(Structure Similarity):评估了图像结构的相似性。
- SSIM的结果在-1到1之间,1表示两幅图像完全相同,而较低的值表示差异越大。通常,SSIM值越接近1,表示图像质量越高,与人眼感知更相符。
- SSIM在评估图像质量时比PSNR更具代表性,因为它考虑了更多的人眼感知因素,例如对比度和结构信息。
- SSIM是一种更复杂的图像质量度量指标,它考虑了亮度、对比度和结构三个方面的信息,从而更全面地评估了图像之间的相似性。SSIM的计算方式包括以下三个组成部分:
意义
- 计算资源与精度的平衡,需要根据系统的不同资源状况动态调整网络规模,显然多阶段的模型有着这样的好处