【图像去模糊】SRN-DeblurNet 深入浅出

0 前言

图像去模糊是数字图像处理领域中不可或缺的任务,是计算机视觉和图像处理领域的一个重要问题,其目标是从模糊图像中恢复清晰的图像。本文将介绍用于图像去模糊任务的 SRN-DeblurNet 模型,尝试用通俗的语言以及过程可视化,帮助读者深入浅出地理解 SRN-DeblurNet 的完整结构。在文章的最后,我们将手把手教你实现 SRN-DeblurNet 的代码。以下是文章的结构:

  1. 任务介绍
  2. 模型介绍
  3. 代码实现

1 任务介绍

1.1 图像为什么会模糊?

试想这么一个场景:你和你的女朋友(男朋友)现在在你们梦寐以求的旅游地——特卡波湖,正打算拍摄唯美的星空。在漫天繁星下平静湖畔旁,你的女朋友(男朋友)已经摆好了姿势,瞪大了眼睛兴致勃勃地等你拍照,你摁下了刚买的微单相机的快门,咔嚓——。你充满自信地向女朋友(男朋友)比了个耶的手势,招呼她(她)过来和你一起看看成片。哦吼!结果发现星星因为你的手抖在天空上形成了拖影。你的女朋友(男朋友)已经开始怒发冲冠,你见状赶紧安慰道:“没关系宝贝,这样子看好像流星哦~”,结果她(她)说:“嗯嗯,我的脸确实糊得像被流星撞了”。

1.1.1 模糊的原因

这种场景即使对于拍照老鸟应该也是经历过的事情,图片模糊是再常见不过的事情了。而在本文中,我们将聚焦于由相机拍摄引起的模糊。相机拍摄导致的模糊有多种原因,其中包括光圈、快门速度、感光度(对于数码相机)、对焦等多个参数的相互影响。在上面的场景中,模糊的原因可能有以下几种:

  1. 相机抖动:手持拍摄时的相机抖动会导致图像模糊。尤其是在夜晚,对于拍照经验不足的新手来说,很容易设置了过长的快门时间,导致即便是轻微的抖动也可能在低快门速度下引入模糊。
  2. 场景运动:这张照片的不如人意也许不只怪你一个,也许你的女朋友(男朋友)也要背锅(不过不建议你和她讲,除非你们关系特别好)。当拍摄对象是可以移动的生物或物体时,其在快门打开期间移动,会导致图像模糊。这被称为拍摄过程中的运动模糊。
  3. 景深问题:景深不足或过大可能导致部分图像模糊。也许你为了在夜晚拍出更好的效果而设置了较大的光圈,可能会导致很小的景深(原因可以类比小孔成像的孔径变大了),不在景深里的物体会失焦,因此你也就很难把女朋友(男朋友)和星空同时拍的清晰。

在这里插入图片描述

1.1.2 模糊的种类

除了上述三个常见的原因之外,造成图像模糊的原因还有镜头缺陷、后期处理失误等。这些原因将会导致不同样式的模糊,模糊的种类主要包括:

  1. 运动模糊:在相对较长的曝光期间,相机或被拍摄物体的相对运动导致图像产生模糊
  2. 失焦模糊:目标位置与相机焦距不对应,相机无法对焦到目标时,会拍摄到离焦模糊的图片
  3. 高斯模糊:高斯卷积得到的模糊图像
  4. 混合模糊:当一个图片同时被多种因素影响时,造成的模糊就是混合模糊,比如相机拍摄在失焦状态下的高速运动物体时,得到的模糊会混杂运动模糊以及失焦模糊

在这里插入图片描述

研究图像去模糊具有广泛的意义,不仅在于能够改善图片质量,让您不会因为给女朋友(男朋友)拍了丑照而丧失家庭地位,而且还对图像相关的上游任务,如三维重建、照片超分辨率、图像识别等,产生积极影响。这一领域的研究不仅对传统相机制造商重要,也对具有较小传感器的新兴的手机摄影技术有着重要意义。

1.2 去模糊的难度在哪?

看到这里,我想你应该对图像模糊有了最基本的认识,那么就让我们开始着手于去模糊吧!不过,开始了解去模糊的具体技术的同时,我们不妨先理解一下去模糊任务的难点、背景知识以及整个任务的概述。我们将介绍两个去模糊的方法分类:非盲去卷积与盲去卷积。

1.2.1 非盲去卷积

非盲去卷积,其实就是在已知成像 PSF 的情况下,对模糊图像进行去模糊的操作。可是,PSF 又是什么呢?要彻底理解 PSF,就需要学习傅里叶变换的知识,但想在一篇文章中讲清这么多是相当困难的,我们不妨直观地理解它。下图©是运动模糊所对应的 PSF,在这里,我们假设相机或被拍摄物体的相对运动轨迹呈直线。

在这里插入图片描述

可以看出,图片施加了该运动模糊对应的 PSF 后,向着 PSF 直线方向产生了运动模糊。这是因为 PSF 可以理解为,相机中原本只应该呈现为一个点的光,呈现成了一条直线。于是相对而言,对于感光元件的某一个像素,其接受到的光线为原先的一条线段的光线。换个更加直观的角度思考,在你脑海里将整个图像变得透明,然后尝试让图像在你脑子里沿着 PSF 的方向移动产生重影,就得到了运动模糊。看吧,是不是也没有那么难理解!

在已知了 PSF 以及模糊图像的情况下,在理论上我们将很容易得到清晰的图像。这里列出一些计算公式,仅用于解释这个过程是非常简单的。其中, G ( u , v ) G(u,v) G(u,v)为模糊图像, F ( u , v ) F(u,v) F(u,v)原始的清晰图像,而 H ( u , v ) H(u,v) H(u,v)则为由 PSF 得到的退化函数。值得一提的是,公式均作用于频率域( u , v u,v u,v为频率域坐标),在频率域,卷积(PSF 的作用形式)相当于乘法,而去卷积(去模糊)相当于频域上进行除法。

退化函数作用于原图像: G ( u , v ) = H ( u , v ) F ( u , v ) 图像去模糊: F ( u , v ) = G ( u , v ) / H ( u , v ) 退化函数作用于原图像: G(u,v)=H(u,v)F(u,v) \\ 图像去模糊: F(u,v)=G(u,v)/H(u,v) 退化函数作用于原图像:G(u,v)=H(u,v)F(u,v)图像去模糊:F(u,v)=G(u,v)/H(u,v)

但实际情况可完全不是这样,因为我们没有考虑噪声。添加噪声后的退化函数作用方程长这样: G ( u , v ) = H ( u , v ) F ( u , v ) + n o i s e G(u,v)=H(u,v)F(u,v)+noise G(u,v)=H(u,v)F(u,v)+noise。这时候,我们再简单粗暴地进行除法,将会放大高频噪声,导致图像输出充满了噪声。对此,领域内已经有了相当多的研究,诸如 维纳滤波Richardson-Lucy 等方法,都能在图像有噪声的情况下取得较好的效果。

在这里插入图片描述

1.2.2 盲去卷积

盲去卷积。既然卷积方法如此之好,能把糊成这样的图片都给整清晰了,那是不是什么图片都能完美去模糊了叻?显然不行,其中的关键就在于我们需要提前知道 PSF,但这是相当困难的,因为 PSF 与模糊息息相关,而实际的模糊不仅来源于相机镜头的缺陷,还有可能来自变幻莫测的物体移动等等。那么在不知道 PSF 的情况下,我们如何进行图片去模糊?答案是盲去卷积。

传统的盲去卷积依旧揪着 PSF 不放,视图通过估计的模糊核来估计清晰图像。这听起来就难得不得了。目前的常见方法都在尝试使用一些先验知识,通过启发式参数调整、复杂的数学建模以及一系列的概率求解,来对清晰图像进行重建求解。但也有部分研究取得了不错的效果,比如 非负矩阵分解MAF,但总体而言,其应用仍旧存在数据局限性、噪声敏感性以及非唯一性等问题。

当一个看似不可能完成的任务出现了的时候,别急着放弃,这不还有深度学习呢!由于深度学习强大的表征学习能力、模型的灵活性以及计算能力的提升,其已经广泛应用于语言领域、图像领域等,在图像去模糊领域,深度学习也已经取得了显著的成果。目前已有的基于深度学习的图像去模糊模型已经有 DeblurGANSRN-DeblurNetEDVRPSS-NSC 等等,本文将介绍 SRN-DeblurNet,尝试将深度学习在图像去模糊领域的作用机制直观并通俗易懂地呈现出来。

2 模型介绍

2.1 SRN-DeblurNet 概述

正如前文介绍的,基于传统算法的去模糊方法需要深厚的数学知识,理解起来非常困难,而基于深度学习的去模糊方法大放异彩。本文介绍的 SRN-DeblurNet(CVPR2018)方法便是一种基于深度学习盲去模糊方法,它沿用了去模糊领域广泛应用的从粗到细(coarse-to-fine)的方案,提出了一个新的用于去模糊任务的尺度循环网络(Scale-recurrent Network),采用尺度训练方法,使用了编码器-解码器,ResBlock 网络等,该方法有两大突出特点:

  • SRN-DeblurNet 是第一篇将循环神经网络 RNN(Recurrent Neural Network)引入去模糊任务,而此前基于深度学习的去模糊领域通常使用 CNN(卷积神经网络),该文章的引用已达 1090 次(Google scholar),在基于深度学习的去模糊领域中具有开创性的意义。
  • SRN-DeblurNet 相比于同期其它的基于深度学习的方法,它的网络结构更简单,参数数量更少,训练更高效、容易;而且该网络的去模糊效果在相关邻域其它论文中得到了一致的认可。

以下将对 SRN-DeblurNet 网络的特点展开叙述,共分为三个部分:

  • SRN 多尺度循环网络
  • 具有残差块的编码器解码器结构
  • 损失函数

2.2 SRN 多尺度循环网络

在这里插入图片描述

SRN 多尺度循环网络是 SRN-DeblurNet 的基础架构,理解它有助于帮助我们从宏观上理解整个网络,但这里提到的一些概念可能暂时只是一个上层的抽象,读者只需理解其思想,而不用深入了解其具体的实现,这些会在后续部分叙述。

2.2.1 多尺度

SRN 的全称是 Scale-recurrent Network,多尺度便是多 Scale。我们不妨以人的视角理解多尺度。对于人类而言,如何识别出一只猫?有的人依靠认耳朵,来识别出猫,但就像下图的对比中,左边的老虎的耳朵与猫咪的耳朵并无二致。这时又会有人说,我还可以看它的体型体态,来识别出这是老虎还是猫。

在这里插入图片描述

没错,这样当然是可以的,但对于计算机而言,同时关注到这两个信息是困难的,因为他们处于不同的尺度中。若我们固定计算机识别图片时的感受野,它过小时可能难以识别到较大尺度的信息,过大则反之。并且,很多时候我们的数据集也会有尺寸缩放问题。这时候,我们就需要多尺度的网络来识别、捕捉不同尺度的信息。而 SRN 将这种操作组织成了从粗糙到精细的过程,类似金字塔的结构,最终得到全分辨率的清晰图像。这种多尺度操作能帮助模型更全面地理解图像,尤其在处理尺寸变化的数据集时更具鲁棒性。

2.2.2 循环神经网络

在上一节中,我们介绍了 SRN 的多尺度思想。那么如何将多尺度的信息糅合在一起呢?这就需要循环神经网络了。对于普通的神经网络,他们大多只能处理一个一个的输入,并且前一个输入和后一个输入是完全没有联系的,而循环神经网络能够将先前的信息利用起来。举个词性分析的例子,我现在需要获取“我 让 坐”的词性,假如我们单独看“坐”,可能会觉得这是一个名词,但若放眼整个句子,不难发现“坐”是“坐位”的意思,也就应该看作是名词了。循环神经网络能够将先前的信息一并输入到网络中,便能够解决这个问题。而在 SRN 网络中,笔者认为在两个地方都运用了循环神经网络。

第一,将上一层获取到的清晰图片放到当前层,与当前尺度的模糊图像一并作为输入。

在这里插入图片描述

第二,将上一层中间提取的特征与当前层获取到的特征融合(如下图所示),再使用解码器进行解码(在后续部分会解释)。

在这里插入图片描述

上图中蓝色块为 ConvLSTM(卷积长短期记忆网络),ConvLSTM 是一种专为处理时空序列数据设计的神经网络结构,它结合了卷积神经网络(CNN)和长短期记忆网络(LSTM)的优点以更有效地捕捉时空相关性。与传统的 LSTM 不同,ConvLSTM 的计算单元包括卷积层,它在时间上处理序列数据,并且在空间上也能够利用卷积操作处理图像或多维数据。此外,ConvLSTM 通过卷积运算实现了参数共享,能够在时空维度上并行计算,减少了参数数量并提高了计算效率。这使得 ConvLSTM 非常适合处理大规模时空数据,并能够更好地捕获时空信息。因此,ConvLSTM 在需要考虑时空依赖关系的任务中,如视频数据处理或图像序列预测,表现更为出色。

2.2.3 跨尺度共享网络权重

对于循环神经网络,我们常用的做法是共享网络权重,在 SRN 中,即共享网络的参数。原因有三点:

  1. 由于我们在每个尺度下的任务是一样的,所以为了让网络也在不同尺度下做同样的事情,对尺度方面的数据增强,我们共享网络的参数。
  2. 若每个尺度下的网络参数不限制为一致,则很容易导致解空间不受限制,模型将会产生过拟合现象,即“复习考试选择背答案”的现象。
  3. 降低了模型训练的难度,且更加稳定。

2.2.4 双线性插值上采样

在 2.2.2 节的第一点中,我们提到了将上一层获取到的清晰图片放到当前层的操作。但是,尺度不一样的图片怎么合在一起呢?SRN 选择的是使用双线性插值进行上采样。插值指利用已知的点来“猜”未知的点,线性插值是用线性函数拟合两个点之间的点,那么双线性插值则是在两个维度均使用线性插值,统共用四个点来拟合他们之间的点。如下图所示。

在这里插入图片描述

2.2.5 公式

最后,我们尝试来理解一下一条抽象的公式:

I i , h i = N e t S R ( B i , I i + 1 ↑ , h i + 1 ↑ ; θ S R ) I^i,h^i=Net_{SR}(B^i,I^{i+1↑},h^{i+1↑};\theta_{SR}) Ii,hi=NetSR(Bi,Ii+1,hi+1;θSR)

下面是公式符号的解释:

  • i i i:不同的尺度, i = 1 , 2 , 3 i=1,2,3 i=1,2,3,代表三个不同尺度, i = 1 i=1 i=1代表最后输出的分辨率最大的尺度
  • I i I^i Ii:当前尺度下输出的清晰图片
  • h i h^i hi:当前尺度下捕捉到的隐藏状态特征
  • N e t S R Net_{SR} NetSR:单个尺度下的网络,用模糊图片生成清晰图片。具体架构会在下一部分介绍。
  • B i B^i Bi i i i尺度下输入的模糊图片
  • θ S R \theta_{SR} θSR:模型训练的参数
  • ↑ ↑ :上采样操作

从公式中能够更加直接地理解整个 SRN-DeblurNet 网络的思想。

  1. 上一层的输出 I i + 1 ↑ I^{i+1↑} Ii+1也要作为当前层的输入
  2. 模型训练参数 θ S R \theta_{SR} θSR跨尺度共享
  3. 上一层中间提取的特征 h i + 1 ↑ h^{i+1↑} hi+1与当前层获取到的特征融合得到当前层特征 h i h^{i} hi

2.3 具有残差块的编码器解码器结构

在前面我们详细介绍了 SRN-DeblurNet 的基础架构——SRN 多尺度循环网络,其中使用了 ConvLSTM 单元来捕获图像序列中的时空相关性。此外,编码器解码器结构对于图像到图像的生成任务也十分关键,然而,编码器解码器结构并不适合直接用于本文的任务,原因有以下三点:

  1. 对于去模糊的任务,需要很大的感受野才能处理严重模糊的情况,这意味着需要增加网络的层数,伴随而来的是大量的参数与中间特征;
  2. 在编码器解码器结构中添加过多的卷积操作会使得网络的效率大大降低;
  3. SRN-DeblurNet 采用了循环神经网络。

那么 SRN-DeblurNet 最终是如何让编码器解码器结构能适配本文的任务呢?原来它在编码器和解码器中间插入了残差块,通过具有残差块的编码器解码器结构,SRN 网络能够更充分地利用多尺度信息,有效实现去模糊任务。下面将详细介绍该网络结构中的关键组成部分。

2.3.1 编码器解码器结构

编码器解码器结构是如 U-net 的对称 CNN 结构(如下图所示),该结构在编码器中将输入数据逐步转换为具有较小空间尺寸和更多通道数的特征图,然后在解码器中反向执行编码器的过程将特征图逐步转换为输出。

在这里插入图片描述

2.3.2 卷积神经网络 CNN

卷积神经网络(CNN)和神经网络有什么不同呢?

卷积神经网络同样是层级结构,但层的功能和形式都有一定变化,是传统神经网络的一个改进,输入为复杂庞大的图像数据而不是较为简单的向量。

一个卷积神经网络主要包括以下 5 层:

  1. 数据输入层:对输入图像进行归一化等预处理
  2. 卷积层:"卷积神经网络"名字的由来,也是最重要的层次,使用卷积核进行特征提取和特征映射
  3. 激活层:用激活函数对卷积层输出结果做非线性映射
  4. 池化层:夹在卷积层中间,用于压缩数据和参数的量,减小过拟合
  5. 全连接层:通常在卷积神经网络末尾,和传统神经网络的神经元连接方式一样

我们可以这样通俗地理解卷积神经网络(CNN):它就像是一位擅长发现图案和特征的魔法艺术家,专注于识别图像中的各种形状、边缘和纹理。它的大脑结构采用了一种独特的方式,通过卷积层和池化层的"魔法组合",让它能够从数据的细微之处捕捉信息,这个"魔法组合"包括了卷积运算和池化操作,它们共同作用于输入数据,以识别和提取出有用的特征,就像一位艺术家从细枝末节中汲取创作灵感。

想象一下,卷积层就像是这位艺术家的放大镜,艺术家拿着放大镜细心观察画布上的每一块小区域,并用自己独特的画笔勾勒出图案。而池化层则像是一位聪明的助手,帮助艺术家过滤掉不太重要的细节,让主要的特征更加突出。

这个神奇的艺术家不同于传统的绘画师,它不需要事先知道画面的精确描述,只需通过一些已知的样本进行训练,就能够快速适应新的图像,就像是学会了一种捕捉本质的直觉。这种独特的艺术风格使得 CNN 在图像识别、人脸检测和其他视觉任务中大放异彩,成为计算机视觉领域的一位杰出的艺术家和专家。

2.3.3 残差块

残差块(ResBlock)就像神奇的信息传递桥梁,通过引入跳跃连接,使信息能够更顺畅地穿越整个网络。这种跳跃连接不仅方便信息的传递,还有助于避免信息在网络中丢失或衰减。在图像去模糊任务中,感受野的重要性不言而喻,而 ResBlock 通过其独特的结构,有效地放大了感受野,极大地提高了处理大移动尺度造成的严重模糊的效果。

在去模糊任务中,残差块的存在让网络更容易收敛,减少了参数的数量,提高了训练效率。通过在编码器解码器结构中嵌入 ResBlock,网络不仅更好地保留了空间信息,而且更有效地还原了清晰的图像。

2.3.4 公式

引入了编码器解码器结构后,对 2.2.5 节的网络公式进行修正:

f i = N e t E ( B i , I i + 1 ↑ ; θ E ) , h i , g i = C o n v L S T M ( h i + 1 ↑ , f i ; θ L S T M ) , I i = N e t D ( g i ; θ D ) f^i=Net_E(B^i,I^{i+1↑};\theta_E), \\ h_i,g_i=ConvLSTM(h^{i+1↑},f^i;\theta_{LSTM}), \\ I^i=Net_D(g^i;\theta_D) fi=NetE(Bi,Ii+1;θE),hi,gi=ConvLSTM(hi+1,fi;θLSTM),Ii=NetD(gi;θD)

下面是公式符号的解释:

  • i i i:不同的尺度, i = 1 , 2 , 3 i=1,2,3 i=1,2,3,代表三个不同尺度, i = 1 i=1 i=1代表最后输出的分辨率最大的尺度
  • B i B^i Bi i i i尺度下输入的模糊图片
  • ↑ ↑ :上采样操作
  • N e t E Net_{E} NetE N e t E Net_{E} NetE是具有参数 θ E \theta_E θE的编码器 CNN,将输入图像转换为输入的 1/4 空间尺寸和 4 倍通道数的特征图 f i f^i fi
  • C o n v L S T M ConvLSTM ConvLSTM:每一层都有的卷积长短期记忆网络
  • h i h^i hi:当前尺度下 C o n v L S T M ConvLSTM ConvLSTM捕捉到的隐藏状态特征
  • g i g^i gi:当前尺度下 C o n v L S T M ConvLSTM ConvLSTM生成的特征图
  • N e t D Net_{D} NetD N e t D Net_{D} NetD是具有参数 θ D \theta_D θD的解码器 CNN,将经过 ConvLSTM 网络处理的特征图转换为当前尺度下输出的清晰图片 I i I^i Ii
  • I i I^i Ii:当前尺度下输出的清晰图片

从公式中能够更加直接地理解整个 SRN-DeblurNet 网络的思想。

  1. 上一层的输出 I i + 1 ↑ I^{i+1↑} Ii+1也要作为当前层的输入
  2. 模型训练参数 θ E \theta_{E} θE, θ L S T M \theta_{LSTM} θLSTM, θ E \theta_{E} θE跨尺度共享
  3. 上一层中间提取的特征 h i + 1 ↑ h^{i+1↑} hi+1与当前层获取到的特征融合得到当前层特征 h i h^{i} hi

2.4 损失函数

L = ∑ i = 1 n κ i N i ∣ ∣ I i − I ∗ i ∣ ∣ 2 2 L=\sum_{i=1}^{n}\frac{\kappa_i}{N_i}||I^i-I^i_*||^2_2 L=i=1nNiκi∣∣IiIi22

其中 κ i \kappa_i κi是每个 scale 的权重,默认 κ i \kappa_i κi为 1, N i N_i Ni是图像中像素的数量,起到归一化的作用。

这里用的是 L2 loss 的开平方,L2 loss 是较为理想的 loss 函数,因为可以更容易跑出高一点的 PSNR 值。但是 L2 loss 也容易受到异常值(outliers)的影响。

2.5 总结

现在,我们已经对 SRN-DeblurNet 整个模型有一个比较全面的了解了。是时候概括一下 SRN-DeblurNet 去模糊的全流程了:

作者采用了在”coarse-to-fine”方案中跨多个尺度的循环结构,将输入图像中在不同尺度下采样的模糊图像序列作为输入 ,生成一组对应的锐化图像,在全分辨率下最清晰的是最终的输出。具体来说,在 i 尺度下生成尺度 i 的清晰潜像 I i I^i Ii,当 i≠1 时,生成的潜像经过上采样处理作为下一尺度的输入,和下一尺度的模糊图像 B i − 1 B^{i-1} Bi1 一起输入到 SRN 网络中,重复上述的过程,当 i=1 时,生成的清晰潜像 I 1 I^1 I1就是最终的结果。

工作流的形式为:

  • B 3 B^3 B3 —> SRN—> I 3 I^3 I3
  • I 3 I^3 I3上采样后 + B 2 B^2 B2 —> SRN —> I 2 I^2 I2
  • I 2 I^2 I2上采样后 + B 1 B^1 B1 —> SRN —> I 1 I^1 I1

搭配下图一定能帮助你理解:

在这里插入图片描述

3 代码实现

官方代码 tensorflow 版本:https://github.com/firenxygao/deblur
参考代码:https://github.com/iwtw/SRN-DeblurNet
个人复现的代码:https://github.com/ClareAquarius/SRN_DeblurNet_Pytorch.git

3.1 数据集列表生成器和加载器

3.1.1 生成图像数据对

代码存放在 ./DataSet/data_list.py 中,这段代码主要目的是准备图像数据对,指示训练或评估时所需的输入图像和对应的目标输出图像,将模糊图像和对应的去模糊图像路径配对保存到一个文件,生成 tarin_list 和 eval_list,方便在训练神经网络模型时加载数据

# 训练集
prefix_dir = './DataSet'                   # 文件路径前缀
blur_dir = 'train/blur_lin'           # 模糊的图像数据集的路径
deblur_dir = 'train/deblurred_lin'    # 去模糊的图像数据集的路径
txt_name = './train_list'               # 保存的名称

# 评估集
# prefix_dir = './DataSet'                   # 文件路径前缀
# blur_dir = 'eval/blur'                  # 模糊的图像数据集的路径
# deblur_dir = 'eval/deblur'              # 去模糊的图像数据集的路径
# txt_name = './eval_list'                # 保存的名称

blur_img_list =[]
deblur_img_list = []
# 遍历模糊目录下的图像
for filename in os.listdir(blur_dir):
    # 获取文件的完整路径
    filepath = os.path.join(blur_dir, filename)
    # 判断是否为文件
    if os.path.isfile(filepath):
        blur_img_list.append(os.path.join(prefix_dir, filepath).replace("\\", "/"))
# 遍历去模糊目录下的图像
for filename in os.listdir(deblur_dir):
    # 获取文件的完整路径
    filepath = os.path.join(deblur_dir, filename)
    # 判断是否为文件
    if os.path.isfile(filepath):
        deblur_img_list.append(os.path.join(prefix_dir, filepath).replace("\\", "/"))

# 先检查数据数量上是否有问题
assert len(blur_img_list)==len(deblur_img_list), "用作训练集的模糊图像和去模糊图像应该一一匹配"

# 将数据组织成为一个list,使用文件保存
with open(txt_name, 'w') as file:
    for a,b in zip(blur_img_list,deblur_img_list):
        img_pairs = "{} {}\n".format(a, b)        # 一行数据表示一对模糊\去模糊的数据
        file.write(img_pairs)
file.close()

3.1.2 数据加载器

自定义训练数据集类 Dataset,用于加载训练数据。因为后续在训练过程中想 PyTorch 的 DataLoader 加载数据,因此需要这个类继承自 torch.utils.data.Dataset,必须实现 lengetitem 两个方法,而且为了更好地训练数据,在加载图像时加入了随机裁剪,以生成更多的训练数据。

Dataset 数据集对象,接收图像文件列表 img_list 和裁剪尺寸 crop_size 作为初始化参数。初始化时将输入的图像文件列表裁剪尺寸保存在对象的属性中。

getitem 方法中:根据给定的训练索引 idx 获取对应的图像数据。在这个方法中,首先根据索引获取模糊图像和对应的清晰图像,并确保它们的尺寸相同。然后,随机选择裁剪位置,将裁剪好的模糊和清晰图像下采样为原大小,1/2 规模,1/4 规模大小的三组张量数据。最后,将这些张量数据放入字典 batch 中,对图像进行归一化,使得它们的数值范围在 [-1, 1] 之间。

这个自定义数据集类的作用是对输入的图像数据进行加载、预处理、裁剪和转换,生成多组不同尺寸的模糊图像和对应的清晰图像的张量数据。这样设计的数据集类可以方便地配合 PyTorch 的 DataLoader 进行批处理、打乱和并行加载,用于神经网络模型的训练和评估。

# 自定义数据集类,目的是使用DataLoader可以对数据进行批处理、打乱和并行加载
# 需要创建一个继承自 torch.utils.DataSet.Dataset 的自定义数据集类。这个类应该至少包含两个方法:__len__ 返回数据集的大小,__getitem__ 返回给定索引的数据。

class Dataset(torch.utils.data.Dataset):
    _"""加载和处理训练集(含有blur和deblur对应图片,加入随机裁剪),将图片转化为张量"""_
_    _def __init__(self, img_list, crop_size=(256, 256)):
        _"""_
_        Args:_
_            img_list:  图像文件列表_
_            crop_size: 表示裁剪后的图像大小_
_        """_
_        _super(type(self), self).__init__()
        self.img_list = img_list
        self.crop_size = crop_size
        self.to_tensor = transforms.ToTensor()

    def crop_resize_totensor(self, img, crop_location):
        _"""_
_        根据裁剪位置(crop_location)从原图中裁剪出三个尺寸不同的图像(256x256、128x128和64x64大小),并转化为张量_
_        Args:_
_            img:接收一张图像_
_            crop_location:裁剪位置作为参数_
_        Returns:_
_            将256x256、128x128和64x64大小的三张图像转换为张量形式_
_        """_
_        _img256 = img.crop(crop_location)
        img128 = img256.resize((self.crop_size[0] // 2, self.crop_size[1] // 2), resample=Image.BILINEAR)
        img64 = img128.resize((self.crop_size[0] // 4, self.crop_size[1] // 4), resample=Image.BILINEAR)
        return self.to_tensor(img256), self.to_tensor(img128), self.to_tensor(img64)

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        _"""_
_        根据给定的训练索引idx获取对应的图像数据_
_        Args:_
_            idx:    给定的训练数据索引(可能是多个)_
_        Returns:_
_            batch:  键值对,将256x256、128x128和64x64大小的三张图像转换为批量的张量形式(张量大小在[-1,1]之间)_
_        """_
_        _# 得到idx
        blurry_img_name = self.img_list[idx].split(' ')[-2]
        clear_img_name = self.img_list[idx].split(' ')[-1]
        blurry_img = Image.open(blurry_img_name)
        clear_img = Image.open(clear_img_name)
        # 断言判断模糊图像和清晰图像的size是否相同
        assert blurry_img.size == clear_img.size

        # 随机选择裁剪位置,裁剪出crop_size大小的图片
        # np.random.uniform是NumPy库中的一个随机数生成函数,用于从均匀分布中生成随机样本
        crop_left = int(np.floor(np.random.uniform(0, blurry_img.size[0] - self.crop_size[0] + 1)))
        crop_top = int(np.floor(np.random.uniform(0, blurry_img.size[1] - self.crop_size[1] + 1)))
        # 裁剪的位置(左侧,顶端,右端,底部)
        crop_location = (crop_left, crop_top, crop_left + self.crop_size[0], crop_top + self.crop_size[1])

        # 将裁剪好的图片下采样为256x256、128x128和64x64大小的三个张量
        img256, img128, img64 = self.crop_resize_totensor(blurry_img, crop_location)
        label256, label128, label64 = self.crop_resize_totensor(clear_img, crop_location)

        # 将3个size,模糊与清晰的六组图像数据放入字典batch中
        batch = {'img256': img256, 'img128': img128, 'img64': img64, 'label256': label256, 'label128': label128, 'label64': label64}
        for k in batch:
            batch[k] = batch[k] * 2 - 1.0  # in range [-1,1]
        return batch

下面是测试数据集类 ValDataset,用于加载测试数据,它与 Dataset 区别在于:ValDataset没有引入随机裁剪的操作,因为是测试数据,只是简单的读取完整的照片。

class ValDataset(torch.utils.data.Dataset):
    _"""加载和处理训练集(含有blur和deblur对应图片,不加入随机裁剪),将图片转化为张量"""_
_    _def __init__(self, img_list):
        _"""_
_        Args:_
_            img_list:  图像文件列表_
_        """_
_        _super(type(self), self).__init__()
        self.img_list = img_list
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        _"""_
_        根据给定的训练索引idx获取对应的图像数据_
_        Args:_
_            idx:    给定的训练数据索引(可能是多个)_
_        Returns:_
_            batch:  键值对,将256x256、128x128和64x64大小的三张图像转换为批量的张量形式(张量大小在[-1,1]之间)_
_        """_
_        _# 得到idx
        blurry_img_name = self.img_list[idx].split(' ')[-2]
        clear_img_name = self.img_list[idx].split(' ')[-1]
        img256 = Image.open(blurry_img_name)
        label256 = Image.open(clear_img_name)
        # 断言判断模糊图像和清晰图像的size是否相同
        assert img256.size == label256.size
        img_size = img256.size

        # 将裁剪好的图片下采样为256x256、128x128和64x64大小的三个张量
        img128 = img256.resize((img_size[0] // 2, img_size[1] // 2), resample=Image.BILINEAR)
        img64 = img128.resize((img_size[0] // 4, img_size[1] // 4), resample=Image.BILINEAR)

        label128 = img256.resize((img_size[0] // 2, img_size[1] // 2), resample=Image.BILINEAR)
        label64 = img128.resize((img_size[0] // 4, img_size[1] // 4), resample=Image.BILINEAR)

        # 将3个size,模糊与清晰的六组图像数据放入字典batch中
        batch = {'img256': self.to_tensor(img256), 'img128': self.to_tensor(img128), 'img64': self.to_tensor(img64),
                 'label256': self.to_tensor(label256), 'label128': self.to_tensor(label128), 'label64': self.to_tensor(label64)}
        for k in batch:
            batch[k] = batch[k] * 2 - 1.0  # in range [-1,1]
        return batch

3.2 网络基本模块

3.2.1 pytorch 模块包装

首先在 basic_block.py 代码中使用 pytorch 官方库包装了一些线性层、卷积层、转置卷积的基本模块,用于后续的模块使用,如下所示:

def linear(in_channels, out_channels, activation_fn=None, use_batchnorm=False, pre_activation=False, bias=True,
           weight_init_fn=None):
    _"""pytorch torch.nn.Linear包装器函数_
_        Args:_
_            in_channels:    输入通道数_
_            out_channels:   输出通道数_
_            activation_fn: 激活函数,如果需要任何参数,使用partial()来包装activation_fn_
_            use_batchnorm:  如果 use_batchnorm 参数为 True,则在 线性层前/后 添加批归一化层(nn.BatchNorm2d)_
_            pre_activation: 如果 pre_activation 为 True,则 批归一化层和激活函数 应用在卷积层之前_
_            bias:           是否需要偏置项_
_            weight_init_fn: 初始化函数,如果需要任何参数,请使用partial()来包装初始化函数。默认为None,如果为None,则根据activation_fn自动选择初始化函数。_

_        Example:_
_            linear(3, 32, activation_fn=partial(torch.nn.LeakyReLU, negative_slope=0.1))_
_        """_
        # 实现代码可见github
      
 def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, activation_fn=None, use_batchnorm=False,
         pre_activation=False, bias=True, weight_init_fn=None):
    _"""pytorch torch.nn.Conv2d的包装器_
_    Args:_
_        activation_fn: 激活函数,如果需要任何参数,使用partial()来包装activation_fn_
_        use_batchnorm:  如果 use_batchnorm 参数为 True,则在 卷积层前/后 添加批归一化层(nn.BatchNorm2d)_
_        pre_activation: 如果 pre_activation 为 True,则 批归一化层和激活函数 应用在卷积层之前_
_        bias:           是否需要偏置项_
_        weight_init_fn: 初始化函数,如果需要任何参数,请使用partial()来包装初始化函数。默认为None,如果为None,则根据activation_fn自动选择初始化函数。_

_    Examples:_
_        conv(3,32,3,1,1,activation_fn = partial( torch.nn.LeakyReLU , negative_slope = 0.1 ))_
_    """_   
    # 实现代码可见github
  
  def deconv(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, activation_fn=None,
           use_batchnorm=False, pre_activation=False, bias=True, weight_init_fn=None):
    _"""pytorch torch.nn.ConvTranspose2d的包装器_
_    Args:_
_        activation_fn: 激活函数,如果需要任何参数,使用partial()来包装activation_fn_
_        use_batchnorm:  如果 use_batchnorm 参数为 True,则在 转置卷积层前/后 添加批归一化层(nn.BatchNorm2d)_
_        pre_activation: 如果 pre_activation 为 True,则 批归一化层和激活函数 应用在卷积层之前_
_        bias:           是否需要偏置项_
_        weight_init_fn: 初始化函数,如果需要任何参数,请使用partial()来包装初始化函数。默认为None,如果为None,则根据activation_fn自动选择初始化函数。_
_    Examples:_
_        deconv(3,32,3,1,1,activation_fn = partial( torch.nn.LeakyReLU , negative_slope = 0.1 ))_
_    """_  
    # 实现代码可见github

3.2.2 ResBlock

下面是论文的 ResBlock,它由包含两个卷积层的残差块(Residual Block)组成,是之后的 EBlock 和 DBlock 的组成部分。残差块是深度残差网络(Residual Network)中的基本构建块,残差模块包含两个卷积层和一个跳跃连接,卷积层具有相同数量的卷积核。

在这里插入图片描述

下面的实现过程中 resblock,下面是实现的一些细节:

  • 使用 kernel_size=5x5,padding=2,stride=1 进行特征提取,不会修改图像的形状(h,w)
  • 输入和输出的通道数相同,以保持信息流的连续性(即 in_channelsout_channels 相同)。
  • 不应用批量归一化(BN),这在某些情况下可以使模型学习更稳定,因为 BN 可能会引入一些额外的噪声。
  • 不使用最后的激活函数(last_activation_fn=None),这允许在该残差块的输出后续连接其他层时再应用激活函数,例如,在整个网络的最后一层。
class BasicBlock(nn.Module):
    _"""pytorch torch.nn.conv2d wrapper_
_        含有两个卷积层的残差块(Resblock),由于可能会改变(c,h,w),因此残差块使用1x1卷积层之后连接_

_    Args:_
_        use_batchnorm:      如果 use_batchnorm 参数为 True,则在子层前/后 添加批归一化层(nn.BatchNorm2d)_
_        activation_fn:      激活函数,如果需要任何参数,使用partial()来包装activation_fn_
_        last_activation_fn: 最后一层卷积层之后使用的激活函数_
_        pre_activation:     如果 pre_activation 为 True,则 批归一化层和激活函数 应用在卷积层之前_
_        bias:               是否需要偏置项_
_        weight_init_fn:     初始化函数,如果需要任何参数,请使用partial()来包装初始化函数。默认为None,如果为None,则根据activation_fn自动选择初始化函数。_
_    Examples:_
_        BasicBlock(32, 32, activation_fn = partial( torch.nn.LeakyReLU , negative_slope = 0.1 , inplace = True ))_
_    """_

_    _def __init__(self, in_channels, out_channels, kernel_size, stride=1, use_batchnorm=False,
                 activation_fn=partial(nn.ReLU, inplace=True), last_activation_fn=partial(nn.ReLU, inplace=True),
                 pre_activation=False, scaling_factor=1.0):
        super(BasicBlock, self).__init__()
        # 第一个卷积层(缩小stride倍,通道:in_channels--->out_channels),使用激活函数(activation_fn)
        self.conv1 = conv(in_channels, out_channels, kernel_size, stride, kernel_size // 2, activation_fn,
                          use_batchnorm)
        # 第二个卷积层(channel,h,w都未改变),不使用激活函数(activation_fn)
        self.conv2 = conv(out_channels, out_channels, kernel_size, 1, kernel_size // 2, None, use_batchnorm,
                          weight_init_fn=get_weight_init_fn(last_activation_fn))
        # 如果是下采样,残差块需要使用1x1卷积层来改变通道数
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = conv(in_channels, out_channels, 1, stride, 0, None, use_batchnorm)
        # 在最后添加激活函数
        if last_activation_fn is not None:
            self.last_activation = last_activation_fn()
        else:
            self.last_activation = None
        # 残差块的影响因子
        self.scaling_factor = scaling_factor

    def forward(self, x):
        # 计算残差块,如果是下采样需要进入1x1卷积层改变(c,h,w)
        residual = x
        if self.downsample is not None:
            residual = self.downsample(residual)
        # 第一个第二个卷积层输出
        out = self.conv1(x)
        out = self.conv2(out)
        # 残差块合并
        out += residual * self.scaling_factor
        # 激活函数
        if self.last_activation is not None:
            out = self.last_activation(out)
        return out    
        
  def resblock(in_channels):
    _"""Resblock 使用5x5卷积核,通道数不改变,不使用批量归一化(BN)和 最后的激活函数(the last activation)"""_
_    _return BasicBlock(in_channels, out_channels=in_channels, kernel_size=5, stride=1, use_batchnorm=False,
                      activation_fn=partial(nn.ReLU, inplace=True), last_activation_fn=None)

3.2.3 CLSTM_cell

在 pytorch 库中,只有一维的 LSTM 标准库,没有用卷积实现的 CONV_LSTM 库函数,因此需要效仿 LSTM 手动实现二维卷积版本的 CONV_LSTM 模块。在 conv_lstm.py 中,虽然实现了 CLSTM_cell(单个 Conv LSTM 单元)和 CLSTM(多层的 con_CLSTM 模块),但是在论文中只采用了单层的 CLSTM_cell****。

在实现过程中,注意点有几个:

  1. hidden_state 参数的组织形式:hidden_state 是一个包含隐藏状态 hidden 和记忆 c 的列表或元组。在使用时,确保以正确的顺序传递隐藏状态和记忆状态,以便在 forward 方法中能够正确地解包使用。
  2. 隐藏状态 hidden 和 记忆状态 c 的形状:hiddenc 应该具有与输入 input 相同的图像形状(高度和宽度)。这是因为在 ConvLSTM 中,隐藏状态和记忆状态通常具有与输入数据相同的空间维度,以便能够有效地进行卷积运算和元素级操作。
  3. init 方法的使用:在初始化 ConvLSTM 单元时,使用 init_hidden 方法创建初始的隐藏状态和记忆状态。确保传递正确的批量大小(batch_size)和图像形状(shape),并将这些初始状态作为 hidden_state 的初始值传递给 forward 方法。
class CLSTM_cell(nn.Module):
    _"""定义单个Conv LSTM单元._
_    Args:_
_      input_chans:  int 输入的通道数_
_      num_features: int 状态的通道数(表示c,w隐藏状态等的通道数)_
_      filter_size:  int 过滤器(卷积核)的高度和宽度_
_    """_
_    _# input_chans是输入通道数,num_features是特征数
    def __init__(self, input_chans, num_features, filter_size):
        super(CLSTM_cell, self).__init__()
        self.input_chans = input_chans
        self.filter_size = filter_size
        self.num_features = num_features
        # 计算需要多大的padding才能保证输入输出图像不变
        self.padding = (filter_size - 1) // 2
        self.conv = nn.Conv2d(self.input_chans + self.num_features, 4 * self.num_features, self.filter_size, 1,
                              self.padding)

    # foward需要传入input和hidden_state
    def forward(self, input, hidden_state):
        _"""_
_        Args:_
_            input:          shape:(B,C,H,W),其中C是input_chans_
_            hidden_state:   shape:(Batch, Chans, H, W),h,c两个具有是num_features通道,相同与图像相同形状的隐状态_
_        Returns:_
_            next_h, next_c: shape:(Batch, Chans, H, W),循环一次后下一个(h,c隐状态)_
_        """_
_        _hidden, c = hidden_state
        # 将输入i和隐状态h 在通道维度上拼接,形状为(b,c,h,w)
        combined = torch.cat((input, hidden), 1)
        A = self.conv(combined)
        # 分别得到输入门(决定何时将数据读入单元),遗忘门(用来重置单元的内容),输出门(用来从单元中输出条目),候选记忆单元
        (ai, af, ao, ag) = torch.split(A, self.num_features, dim=1)
        i = torch.sigmoid(ai)
        f = torch.sigmoid(af)
        o = torch.sigmoid(ao)
        g = torch.tanh(ag)
        # 计算下一个记忆元c
        next_c = f * c + i * g
        # 计算下一个隐状态h
        next_h = o * torch.tanh(next_c)
        return next_h, next_c

    def init_hidden(self, batch_size, shape):
        # 用于初始化单个Conv LSTM单元的h和c隐状态
        return (torch.zeros(batch_size, self.num_features, shape[0], shape[1]).cuda(),
                torch.zeros(batch_size, self.num_features, shape[0], shape[1]).cuda())

3.2.4 EBlock/InBlock

EBlock 模块的作用通常是用于神经网络的 编码器 部分,用于提取输入数据的特征表示。

编码器模块包含一个卷积层,然后是三个残差模块,其中卷积层的 stride=2,将上一层输入的通道数翻倍,并将特征图下采样为原来的一半;而后面的 ResBlock 保持输入输出图像形状、通道数不变

在这里插入图片描述

class EBlock(nn.Module):
    _"""编码器(EBlock)由一个5x5conv层+三个Resblock组成,和InBlock输入块组成相同"""_
_    _def __init__(self, in_channels, out_channels, stride):
        super(type(self), self).__init__()
        # 5x5conv层
        self.conv = conv5x5_relu(in_channels, out_channels, stride)
        # 3个ResBlock块
        resblock_list = []
        for i in range(3):
            resblock_list.append(resblock(out_channels))
        self.resblock_stack = nn.Sequential(*resblock_list)

    def forward(self, x):
        x = self.conv(x)
        x = self.resblock_stack(x)
        return x

EBlockInBlock 的模块组成相同,只是 in_channels, out_channels 的参数不同,因此可以 EBlock 也是 InBlock。

3.2.5 DBlock

DBlock 模块的作用通常是用于神经网络的 解``码器 部分,用于将编码器提取的特征逐步转化为目标。

事实上:解码器模块与编码器模块在结构上是对称的,它包含三个残差模块,然后是一个反卷积层,反卷积层用于将特征图的空间大小加倍,并将通道数减半。

在这里插入图片描述

class DBlock(nn.Module):
    _"""解码器(DBlock)由三个Resblock+一层deconv层组成"""_
_    _def __init__(self, in_channels, out_channels, stride, output_padding):
        super(type(self), self).__init__()
        # 3个ResBlock块
        resblock_list = []
        for i in range(3):
            resblock_list.append(resblock(in_channels))
        self.resblock_stack = nn.Sequential(*resblock_list)
        # 5x5deconv层
        self.deconv = deconv5x5_relu(in_channels, out_channels, stride, output_padding)

    def forward(self, x):
        x = self.resblock_stack(x)
        x = self.deconv(x)
        return x

3.2.6 OutBlock

OutBlock 的输出块模块,由三个残差块(Resblock)和一层卷积层(conv 层)组成。通常用于神经网络的输出部分,将特征映射转换为最终的通道数为 3 的输出图像。

  • OutBlock 中,首先通过三个残差块 resblock 对输入张量进行特征提取和转换,其中每个残差块都保持了输入张量的通道数 in_channels。因此,这三个残差块处理后的张量的通道数仍然保持为 in_channels
  • 接着,经过三个残差块处理后的特征张量传递到一个卷积层 conv 中,这个卷积层使用 5x5 的卷积核,步幅为 1,填充为 2,将输入的通道数 in_channels 转换为输出的通道数 3(即将特征通道数变为 3)。
class OutBlock(nn.Module):
    _"""输出块(OutBlock)由三个Resblock+一层conv层组成,将通道数in_channels变为3"""_
_    _def __init__(self, in_channels):
        super(type(self), self).__init__()
        resblock_list = []
        for i in range(3):
            resblock_list.append(resblock(in_channels))
        self.resblock_stack = nn.Sequential(*resblock_list)
        self.conv = conv(in_channels, 3, 5, 1, 2, activation_fn=None)

    def forward(self, x):
        x = self.resblock_stack(x)
        x = self.conv(x)
        return x

3.3 SRN 网络介绍

有了在 3.2 定义的基本模块,此时实现 SRN 网络就相对容易了许多,

  • 网络结构概述:

    • 输入:该网络接受三种不同尺度的图像(原始尺度、1/2 尺度和 1/4 尺度)。
    • 输出:与输入相对应的去模糊后的图像。
  • 网络组成:

    • 输入块:处理三种尺度的图像,并将其组合成一个特征张量。
    • 编码块:一系列编码层,逐步提取和压缩特征。
    • ConvLSTM 单元:具有记忆功能的卷积长短期记忆模块,用于捕捉图像序列中的时间和空间相关性。
    • 解码块:一系列解码层,通过上采样逐步恢复图像尺寸。
    • 输出块:生成最终的去模糊图像。

下面代码中的 forward_step 函数承担了整个网络模型中的在单个尺度上的单步向前传播,负责从编码到解码的信息流转换以及隐藏状态的传递。特征传递和上采样:首先通过特征传递和上采样来跨尺度传递信息。每次迭代通过上一次尺度的图像和当前尺度的特征信息来更新结果。

  • 将输入 x 经过输入块(self.inblock)和两个编码块(self.eblock1self.eblock2)进行特征提取和编码。这些操作将逐渐减少特征的空间维度,并增加通道数。
  • 使用 ConvLSTM 模块处理编码后的特征图 e128 和传入的隐藏状态 hidden_state,这里返回了 ConvLSTM 模块的隐藏状态 hc。这两个状态在形状上与输入 e128 相同。
  • 将 ConvLSTM 模块的隐藏状态 h 输入到两个解码块(self.dblock1self.dblock2)进行特征恢复和解码操作。
  • 最后通过输出块(self.outblock)产生最终的输出 d3,这里使用了残差块将解码结果和编码过程中的特征进行残差连接以提高网络学习能力。
class SRNDeblurNet(nn.Module):
    _"""SRN-DeblurNet主体网络_
_    Examples:_
_        net = SRNDeblurNet()_
_        y = net( x1 , x2 , x3)#x3是最粗糙的图像,而x1是最精细的图像_
_    """_

_    _def __init__(self, upsample_fn=partial(torch.nn.functional.interpolate, mode='bilinear'), xavier_init_all=True):
        super(type(self), self).__init__()
        self.upsample_fn = upsample_fn          # 下采样方法(upsample_fn)是双线性插值(bilinear)
        self.input_padding = None               # 记录上轮的图片输出

        # 输入块
        self.inblock = EBlock(3 + 3, 32, 1)     # 这里的3+3意思是原本输入图像具有3通道,从上一个输出图像具有3通道
        # 编码块(通道c倍增,高h宽w减半)
        self.eblock1 = EBlock(32, 64, 2)
        self.eblock2 = EBlock(64, 128, 2)

        # convlstm单层
        self.convlstm = CLSTM_cell(128, 128, 5)

        # 解码块(通道c倍减,高h宽w翻倍)
        self.dblock1 = DBlock(128, 64, 2, 1)
        self.dblock2 = DBlock(64, 32, 2, 1)
        # 输出块
        self.outblock = OutBlock(32)

        # 初始化参数
        if xavier_init_all:
            for name, m in self.named_modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                    torch.nn.init.xavier_normal_(m.weight)

    def forward_step(self, x, hidden_state):
        _"""单步forward_
_        Args:_
_            x:      (b,c,h,w),其中c是6通道(3通道+3通道)_
_        Returns:_
_            d3:     (b,c,h,w),其中c是3通道_
_            h,c:    (b,c,h,w),其中c为128通道_
_        """_
_        _# 输入块+编码块(通道6(3+3)->32->64->128,h和w在两层编码块变为h/4,w/4)
        e32 = self.inblock(x)
        e64 = self.eblock1(e32)
        e128 = self.eblock2(e64)
        # convlstm
        h, c = self.convlstm(e128, hidden_state)        # 返回convlstm的h和c隐状态,其形状与e128相同
        # 解码块+输出块(通道128->64->32->3,h/4和w/4在两层解码块变为h和w)
        d64 = self.dblock1(h)
        d32 = self.dblock2(d64 + e64)   # 含残差块
        d3 = self.outblock(d32 + e32)   # 含残差块
        return d3, h, c

    def forward(self, b1, b2, b3):
        _"""三次不同规模的forward_
_        Arg:_
_            b1, b2, b3: 原规模,1/2规模,1/4规模的图片_
_        Return:_
_            i1, i2, i3: 经过网络后的原规模,1/2规模,1/4规模的图片_
_        """_

_        _# input_padding是第一次用于填充1/4规模的输入图片
        if self.input_padding is None or self.input_padding.shape != b3.shape:
            self.input_padding = torch.zeros_like(b3)
        # 初始化h,c隐状态(B=b1.shape[0],C=128,H=1/16原H,W=1/16原W)
        # 为什么这里是1/16?因为第一次进入的b3本身就是1/4规模的图片,经过两层编码块后,h和w会2次减半
        h, c = self.convlstm.init_hidden(b1.shape[0], (b1.shape[-2]//16, b1.shape[-1]//16))

        # 第一轮迭代(1/4规模),将b3和input_padding拼接输入
        i3, h, c = self.forward_step(torch.cat([b3, self.input_padding], 1), (h, c))
        # 下一次的h和w隐状态形状:高H=1/8原H,宽W=1/8原W,需要上采样
        c = self.upsample_fn(c, scale_factor=2)
        h = self.upsample_fn(h, scale_factor=2)

        # 第二轮迭代(1/2规模),将b2和i3上采样2倍后拼接输入
        i2, h, c = self.forward_step(torch.cat([b2, self.upsample_fn(i3, scale_factor=2)], 1), (h, c))
        # 下一次的h和w隐状态形状:高H=1/4原H,宽W=1/4原W,需要上采样
        c = self.upsample_fn(c, scale_factor=2)
        h = self.upsample_fn(h, scale_factor=2)

        # 第三轮迭代(原规模)
        i1, h, c = self.forward_step(torch.cat([b1, self.upsample_fn(i2, scale_factor=2)], 1), (h, c))

        return i1, i2, i3

3.4 train.py 介绍

3.4.1 损失函数

计算模型输出图像与训练数据集图像之间的损失情况,通过均方误差(MSE)来度量输出图像与真实清晰图像之间的差异,并计算峰值信噪比(PSNR)以评估图像重建的质量。

M S E = 1 m n ∑ i = 0 m − 1 ∑ j = 0 n − 1 [ I ( i , j ) − K ( i , j ) ] 2 MSE=\frac{1}{mn}\sum_{i=0}^{m-1}\sum_{j=0}^{n-1}[I(i,j)-K(i,j)]^2 MSE=mn1i=0m1j=0n1[I(i,j)K(i,j)]2

P S N R = 10 ∗ log ⁡ 10 ( M A X I 2 M S E ) PSNR=10*\log_{10}(\frac{MAX^2_I}{MSE}) PSNR=10log10(MSEMAXI2)

上面的 MSE 是均方误差: L = ∑ i = 1 n κ i N i ∣ ∣ I i − I ∗ i ∣ ∣ 2 2 L=\sum_{i=1}^{n}\frac{\kappa_i}{N_i}||I^i-I^i_*||^2_2 L=i=1nNiκi∣∣IiIi22

下面的 MAX_DIFF 是一个设定的常量,表示图像可能的最大像素值,由于输入时我们将 0~255 的像素映射到-1~1。因此,损失函数中将 MAX_DIFF 取 2 图片可能的最大像素值。

下面我们中我们使用各规模图像的 MSE 之和作为损失函数,而 PSNR 作为评估图像重建质量的指标

def compute_loss(db256, db128, db64, batch):
    _"""_
_    计算一个batch的损失函数_
_    Args:_
_        db256, db128, db64: 通过模型得到的256*256,128*128,64*64去模糊图像_
_        batch:              键值对,包含三个不同规模的批量训练图像的模糊和去模糊的张量_
_    Return:_
_        键值对:mse表示不同规模的均方误差,psnr表示在最大规模图片上的峰值信噪比(PSNR)_

_    """_
_    _assert db256.shape[0] == batch['label256'].shape[0]

    loss = 0
    loss += mse(db256, batch['label256'])
    # 峰值信噪比(PSNR)
    psnr = 10 * torch.log(MAX_DIFF ** 2 / loss) / log10
    loss += mse(db128, batch['label128'])
    loss += mse(db64, batch['label64'])
    return {'mse': loss, 'psnr': psnr}

3.4.2 学习率递减策略

在训练时使用了学习率衰减策略,采用学习率衰减策略,可以使得模型具有更快的收敛,避免震荡,减少过拟合风险的优点。

具体方法是根据 epoch 设置优化器 optimizer 的学习率,即周期越大,学习率越低,学习率 L R LR LR随训练周期 t t t衰减的函数为:

L R ( t ) = L R 0 × γ ⌊ t T ⌋ LR(t) = LR_0 \times \gamma ^ {\left\lfloor \frac{t}{T} \right\rfloor} LR(t)=LR0×γTt

其中:

  • L R ( t ) LR(t) LR(t)表示训练周期为 t t t时的学习率。
  • L R 0 LR_0 LR0是初始学习率。
  • γ \gamma γ是衰减系数,用于控制学习率的衰减速度。
  • ⌊ t T ⌋ \left\lfloor \frac{t}{T} \right\rfloor Tt表示 t T \frac{t}{T} Tt的整数部分,其中 T T T是一个周期,表示衰减的频率。
def set_learning_rate(optimizer, epoch):
    _"""_
_    使用了学习率衰减策略,根据epoch设置优化器optimizer的学习率,即周期越大,学习率越低_
_    Arg:_
_        optimizer: 优化器_
_        epoch:     当前训练的周期数_
_    """_
_    _optimizer.param_groups[0]['lr'] = config.train['learning_rate'] * 0.3 ** (epoch // 500)

3.4.3 训练过程

下面是 SRN 深度学习模型的训练与验证过程,其中训练参数保存在 config.py 文件中,主要完成了下面的任务:

  1. 数据加载:

    • 从文件中读取训练集和验证集的图像列表,并使用自定义的数据集类(DatasetValDataset)分别将训练图像和测试图像加载到数据加载器(DataLoader)中。
    • 数据加载器用于在训练和验证过程中对数据进行批处理、打乱和并行加载。
  2. 损失函数和优化器:

    • 定义了均方误差损失函数(MSELoss)用于计算模型预测与真实标签之间的差异。
    • 使用 Adam 或 SGD 优化器对神经网络模型参数进行优化。
  3. 模型训练和验证循环:

    • 通过循环遍历训练集数据,进行模型训练。每个批次通过网络后向传播计算损失,并进行参数更新。
    • 验证过程用于评估模型在验证集上的性能。通过禁用梯度计算,使用验证集的数据对模型进行评估,计算损失和其他指标(如 PSNR - 峰值信噪比)。
    • 记录并输出训练和验证集上的损失值、性能指标等。
  4. 学习率衰减:

    • set_learning_rate 函数实现了学习率的指数衰减策略,随着训练周期的增加,逐渐减小学习率。
  5. 模型保存:

    • 保存在验证集上获得最佳 PSNR 值时的模型参数。
  6. 输出信息:

    • 输出了训练过程中每个周期的训练损失和验证损失,以及其他指标(如训练和验证速度)。
if __name__ == "__main__":
    # 1.读入数据集,并放入数据加载器中
    # 读入训练集和测试集数据,其中.read() 读取文件的全部内容为一个字符串,.strip() 用于去除字符串两端的空格和换行符等空白字符。
    train_img_list = open(config.train['train_img_list'], 'r').read().strip().split('\n')
    val_img_list = open(config.train['val_img_list'], 'r').read().strip().split('\n')
    # 将数据集放入DataLoader(数据加载器)中
    train_dataset = Dataset(train_img_list)
    val_dataset = ValDataset(val_img_list)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train['batch_size'],
                                                   shuffle=True, drop_last=True, num_workers=8, pin_memory=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train['val_batch_size'],
                                                 shuffle=True, drop_last=True, num_workers=2, pin_memory=True)
    # 2.均方误差
    mse = torch.nn.MSELoss().cuda()

    # 3.网络
    net = torch.nn.DataParallel(SRNDeblurNet(xavier_init_all=config.net['xavier_init_all'])).cuda()
    # 如果使用之前已训练的网络参数
    if config.train['if_use_pretrained_model']:
        checkpoints = torch.load(config.train['used_params_dir'])
        net.load_state_dict(checkpoints)

    # 4.优化器
    assert config.train['optimizer'] in ['Adam', 'SGD']
    if config.train['optimizer'] == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=config.train['learning_rate'],
                                     weight_decay=config.loss['weight_l2_reg'])
    if config.train['optimizer'] == 'SGD':
        optimizer = torch.optim.SGD(net.parameters(), lr=config.train['learning_rate'],
                                    weight_decay=config.loss['weight_l2_reg'], momentum=config.train['momentum'],
                                    nesterov=config.train['nesterov'])

    # 定义一些用于记录训练的参数
    train_loss_log_list = []        # 用于记录log记录时训练集的损失值
    val_loss_log_list = []          # 用于记录log记录时的验证集损失值
    first_val = True
    t = time()

    # 定义一些用于验证集上最优的参数和模型
    best_val_psnr = 0               # 记录目前为止在验证集上达到的最佳PSNR(峰值信噪比)值
    best_net = None                 # 验证过程中达到最佳 PSNR 时的模型
    best_optimizer = None           # 验证过程中达到最佳 PSNR 时的优化器

    for epoch in tqdm(range(config.train['num_epochs']), file=sys.stdout, desc=str(config.train['num_epochs'])+' epoches'):
        # 根据当前 epoch 设置学习率
        set_learning_rate(optimizer, epoch)

        # 训练
        for step, batch in enumerate(train_dataloader):
            # 这里的batch是键值对,但是值的第0维度是bathsize
            # 将batch数据移到GPU上,不需要计算梯度
            for k in batch:
                batch[k] = batch[k].cuda()
                batch[k].requires_grad = False
            # 得到网络预测结果
            db256, db128, db64 = net(batch['img256'], batch['img128'], batch['img64'])
            # 计算损失
            loss = compute_loss(db256, db128, db64, batch)

            # 反向传播和网络参数更新
            backward(loss, optimizer)

            # 将loss从gpu移动到cpu上
            for k in loss:
                loss[k] = float(loss[k].cpu().detach().numpy())
            # 记录训练的损失值
            train_loss_log_list.append({k: loss[k] for k in loss})


        # 验证(间隔log_epoch个周期验证一次)
        if first_val or epoch % config.train['log_epoch'] == config.train['log_epoch'] - 1:
            first_val = False
            # 验证时不需要记录梯度
            with torch.no_grad():
                for step, batch in enumerate(val_dataloader):
                    for k in batch:
                        batch[k] = batch[k].cuda()
                        batch[k].requires_grad = False
                    db256, db128, db64 = net(batch['img256'], batch['img128'], batch['img64'])
                    loss = compute_loss(db256, db128, db64, batch)
                    for k in loss:
                        loss[k] = float(loss[k].cpu().detach().numpy())
                    val_loss_log_list.append({k: loss[k] for k in loss})
                # 计算了训练损失(MSE和)的平均值
                train_loss_log_dict = {k: float(np.mean([dic[k] for dic in train_loss_log_list])) for k in
                                       train_loss_log_list[0]}
                val_loss_log_dict = {k: float(np.mean([dic[k] for dic in val_loss_log_list])) for k in
                                     val_loss_log_list[0]}


                # PSNR的值越大越好
                if best_val_psnr < val_loss_log_dict['psnr']:
                    best_val_psnr = val_loss_log_dict['psnr']   # 保存最优的PSNR值
                    best_net = net.state_dict()                 # 更新最优模型参数

                # 将训练集和测试集的损失列表清空
                train_loss_log_list.clear()
                val_loss_log_list.clear()

                tt = time()
                log_msg = ""
                log_msg += "epoch {} , {:.2f} imgs/s".format(epoch, (
                            config.train['log_epoch'] * len(train_dataloader) * config.train['batch_size'] + len(
                        val_dataloader) * config.train['val_batch_size']) / (tt - t))

                log_msg += " | train : "
                for idx, k_v in enumerate(train_loss_log_dict.items()):
                    k, v = k_v
                    if k == 'acc':
                        log_msg += "{} {:.3%} {}".format(k, v, ',')
                    else:
                        log_msg += "{} {:.5f} {}".format(k, v, ',')
                log_msg += "  | eval : "
                for idx, k_v in enumerate(val_loss_log_dict.items()):
                    k, v = k_v
                    if k == 'acc':
                        log_msg += "{} {:.3%} {}".format(k, v, ',')
                    else:
                        log_msg += "{} {:.5f} {}".format(k, v, ',' if idx < len(val_loss_log_list) - 1 else '')
                tqdm.write(log_msg, file=sys.stdout)
                sys.stdout.flush()
                t = time()


    print("最优PSNR为:", best_val_psnr)
    torch.save(best_net, config.train['save_params_dir'])

3.5 测试过程

测试过程相对简单,因为没有计算误差,只需要简单地调用已经保存的模型参数,对给定改的数据进行预测即可,相比较训练过程,测试过程相对简单,测试参数也保存在 config.py 文件中。

  1. 获取测试图像列表:

    • 通过 get_test_list() 函数获取测试图像文件夹中的图像列表。
  2. 图像处理和转换:

    • to_tensor_list(filepath) 函数用于读取给定路径的图像,并将其转换为模型可接受的张量格式。这里对图像进行了三种不同尺度的预处理,分别为原始尺度、0.5 倍尺度和 0.25 倍尺度。这三种尺度的图像经过转换为张量后,被组合成一个批次(batch)。
  3. 模型加载和推理:

    • 首先,加载了训练好的模型参数,并将模型参数加载到定义好的神经网络模型中。
    • 对测试集中的每张图像进行处理和推理:
      • 调用 to_tensor_list() 将测试图像转换为模型输入所需的张量格式。
      • 使用经过训练的模型对张量进行推理,获得模型对图像的预测结果 db256
      • 对模型输出的张量进行了后处理,包括去除批次维度、调整张量范围到[-1, 1]之间,并将张量转换为图像格式。
  4. 结果保存:

    • 将模型输出的图像结果保存到指定的输出目录中,命名规则为 output_prefix + input_name_list
def get_test_list():
    input_dir = config.test['input_dir']
    input_name_list = []
    input_filepath_list = []
    for filename in os.listdir(input_dir):
        # 获取文件的完整路径
        filepath = os.path.join(input_dir, filename).replace("\\", "/")
        # 判断是否为文件
        if os.path.isfile(filepath):
            input_name_list.append(filename)
            input_filepath_list.append(filepath)
    return input_name_list, input_filepath_list


def to_tesnor_list(filepath):
    _"""读取filepath处的图片"""_
_    _blurry_img = Image.open(filepath)
    size = blurry_img.size
    to_tesnor = transforms.ToTensor()

    # 原规模,0.5规模,0.25规模
    img1 = blurry_img
    img2 = img1.resize((size[0] // 2, size[1] // 2), resample=Image.BICUBIC)
    img3 = img2.resize((size[0] // 4, size[1] // 4), resample=Image.BICUBIC)
    # 转换为tensor,并增加第0维
    img1 = torch.unsqueeze(to_tesnor(img1), 0)
    img2 = torch.unsqueeze(to_tesnor(img2), 0)
    img3 = torch.unsqueeze(to_tesnor(img3), 0)
    # 组合成为batch字典输出
    batch = {'img256': img1, 'img128': img2, 'img64': img3}
    for k in batch:
        batch[k] = batch[k] * 2 - 1.0  # in range [-1,1]
    return batch


if __name__ == "__main__":
    # 加载网络结构
    net = torch.nn.DataParallel(SRNDeblurNet(xavier_init_all=config.net['xavier_init_all'])).cuda()
    checkpoints = torch.load(config.test['model_params'])
    net.load_state_dict(checkpoints)

    # 加载模型参数
    input_name_list, input_filepath_list = get_test_list()
    to_tesnor_list(input_filepath_list[0])

    # 处理每张照片
    for index, input_filepath in enumerate(input_filepath_list):
        batch = to_tesnor_list(input_filepath)
        with torch.no_grad():
            db256, _, _ = net(batch['img256'], batch['img128'], batch['img64'])

            # 删去batchsize维度,限制到Imaes的Tensor(-1,1)
            db256 = torch.squeeze(db256, 0).clamp(-1, 1)
            db256 = (db256 + 1) / 2

            # 从tensor转换到image
            to_pil = transforms.ToPILImage()
            new_img = to_pil(db256)

            # 展示与保存
            # new_img.show()
            new_img.save(config.test['output_dir'] + "/" + config.test['output_prefix'] + input_name_list[index], "PNG")  # 保存为PNG格式

3.6 一些问题/体会

3.6.1 张量范围处理

在读取图片的时候,我们将[0.255]的像素归一化到了[-1,1],因此在输出结果应该在[-1,1]之间,最后我们将[-1,1]的图片重新映射为[0,255]但实际运行过程中。

  • Clamp 操作:在张量范围超出[-1, 1]时,可以使用 torch.clamp() 方法将超出范围的值限制在指定范围内。例如,对模型输出的张量进行限制范围操作,确保其值在合理的范围内。
db256 = db256.clamp(-1, 1)
  • 反归一化:在输出图像之前,将[-1,1]的张量进行反归一化操作,确保最终输出的图像像素值范围在[0, 255]内,确保输出图像的像素值是合法的。
# 反归一化
db256 = (db256 + 1) / 2  # 将[-1, 1]范围的张量重新映射到[0, 1]范围
db256 = db256 * 255  # 将[0, 1]范围的张量映射到[0, 255]范围

3.6.2 处理尺寸不一致的问题

在 conv_LSTM 中,隐状态要求和输入图片的尺寸大小应该一致,但是在 SRN 中,每次输入的图片规模都比之前扩大一倍,怎么解决这个问题?

在不同规模的循环中,我们需要对隐藏状态进行上采样,这样每次在输入时,保存下次输入和隐藏状态的 shape 能够匹配

3.6.3 模型处理不同尺寸图片的原因

思考:为什么该模型能够对于不同尺寸(size)的图片进行处理?为什么能对不同尺度的图像进行去模糊处理?

  • 卷积操作的灵活性:模型中使用了卷积,转置卷积等操作,卷积操作的优势在于并没有显式限定输入图像的尺寸。卷积核在图像上进行滑动,尺寸并不影响卷积核的计算。这使得模型具有对不同尺寸图片的适应能力。
  • 多尺度特征提取:模型使用了编码器-解码器结构,通过堆叠多层编码和解码块,可以提取多个尺度的特征。这些特征在不同尺度上捕获了图像的细节和整体信息,有助于对不同尺度图像的去模糊处理。
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值