图像去噪 是在 【深度学习】计算机视觉(1—13) 系列文章的基础上,再次针对深度学习(尤其是图像去噪方面)的基础知识有更深入学习和巩固。
1 DnCNN
通过最近的学习,我发现了解一个网络最好的方法就是看代码。无论是看论文、看别人讲解还是自己复现,效果远不如自己读代码。因此,建议在学习任何网络的时候多看代码。
1.1 网络结构
第一部分:Conv(3 * 3 * c * 64) + ReLU;
第二部分:Conv(3 * 3 * 64 * 64) + BN(batch normalization) + ReLU;
第三部分:Conv(3 * 3 * 64)。
每一层都使用零填充,以此防止产生人工边界。
在许多低级视觉应用中,通常要求输出图像大小应与输入图像大小保持一致。这可能会导致边界伪影。DnCNN在卷积之前直接填充零,以确保中间层的每个特征图都与输入图像具有相同的大小。简单的零填充策略不会导致任何边界伪影。
1.1.1 残差学习
以前以为所有与残差有关的都叫残差网络,但是残差实际上只表示一个公式。
DnCNN的输入是噪声观察y = x + v
,模型内部并不像ResNet一样存在跳远连接,而是在网络的输出使用残差学习,采用残差学习公式来训练残差映射R(y) ≈ v
,得到噪声图像,然后由 x = y – R(y)
得到原始图像。
1.1.2 Batch Normalization (BN)
1.1.2.1 背景和目标
批归一化是DnCNN第二个特点。在阅读代码的时候,我对model.train()
和model.eval()
产生了疑问,它们的作用是什么?
一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。
- model.train():启用BN层和Dropout。
- model.eval():不启用BN层和Dropout。
Dropout此处不做过多解释,对于BN层,model.train()是保证BN层能够用到每一批数据的均值和方差;model.eval()是保证BN层能够用全部训练数据的均值和方差。训练完 train 样本后,生成的模型 model 要用来测试样本了,在 model(test) 之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。
如果在测试的时候启用了Dropout,那么网络会随机丢弃某几个神经元,这样神经网络每一次生成的结果不是固定的,生成的质量也不稳定。那BN是如何影响网络的?这就要先学习BN具体是如何实现的。
1.1.2.2 核心问题
在CNN训练时,绝大多数都采用mini-batch使用随机梯度下降算法进行训练,那么随着输入数据的不断变化,以及网络中参数不断调整,网络的各层输入数据的分布则会不断变化,那么各层在训练的过程中就需要不断的改变以适应这种新的数据分布,从而造成网络训练困难,难以拟合的问题。
我的理解:针对网络的每一层(例如x层),我们需要保证不管输入什么样的图片(数据),经过网络层层输入到达x层后,它们的数据分布总是类似的,这样我们就能针对x层的特点去修改x层的权重,慢慢拟合,而不是刚适应了某张图片的数据特点,换了一张图片又要大幅度调整。
BN算法解决的就是这样的问题,他通过对每一层的输入进行归一化,保证每层的输入数据分布是稳定的,从而达到加速训练的目的。
1.1.2.3 步骤
-
Standardization:首先对𝑚个𝑥进行 Standardization,得到 zero mean unit variance的分布𝑥̂ 。
其中, x k x^{k} xk表示输入数据的第k维, E [ x k ] E[x^{k}] E[xk]表示该维的平均值, V a r [ x k ] \sqrt{Var[x^{k}]} Var[xk]为标准差。 -
scale and shift:然后再对𝑥̂ 进行scale and shift,缩放并平移到新的分布𝑦,具有新的均值𝛽方差𝛾。
我们思考一个问题,在第一步中,减均值除方差得到的分布是正态分布,我们能否认为正态分布就是最好或最能体现我们训练样本的特征分布呢?不能。这种分布不一定是前面一层要学习到的数据分布,这样强行归一化就会破坏掉刚刚学习到的特征,比如数据本身就很不对称(不符合正态分布),或者激活函数未必是对“方差为1”的数据最好的效果,比如Softmax激活函数,在-1~1之间的函数的梯度不大,那么非线性变换的作用就不能很好的体现。换言之就是,减均值除方差操作后可能会削弱网络的性能。
针对上述问题,BN算法在第二步中设置了两个可学习的变量γ和β,然后用这两个可学习的变量去还原上一层应该学到的数据分布。
1.1.2.4 BN算法在训练和测试时的应用
了解完BN的原理后我又有个疑问,如果训练过程中启用了BN,但是在测试的时候没有启用BN,那么输入的测试图片在x层没有经过处理,它的分布可能不适用于我们训练好的权重,影响模型效果。而且DnCNN最终的输出是噪声图,所以我们不需要担心正则化会影响原始图片或者输入图片,为什么要禁用呢?其实我这里的启用和禁用理解错了。
训练时:首先提取每次迭代时的每个mini-batch的平均值和方差进行归一化,再通过两个可学习的变量恢复要学习的特征。
测试时:没有mini-batch了,即平均值为所有mini-batch的平均值的平均值,而方差为每个batch的方差的无偏估计。
所以BN的启用和不启用,不是说这层的存在与否,而是说这层的参数是否固定。
2 IPT
华为、北大、悉大以及鹏城实验室的研究者提出了一个名为IPT(Image Processing Transformer的预训练 Transformer 模型,用于完成超分辨率、去噪、去雨等底层视觉任务。IPT 具备多个头结构与尾结构用于处理不同的任务,不同的任务共享同一个 Transformer 模块。预训练得到的模型经过微调即可在多个视觉任务上大幅超越对应任务上的当前最好模型。
2.1 网络结构
IPT使用了一个多头多尾共享躯干的结构,应对不同的任务,有针对性不同的头部和尾部,分别使用不同的处理方式。
中间是一个transformer编解码器结构。将头部输出的特征图像unfold成“词向量”形式和位置嵌入相加后输入encoder。encoder是常规的结构,包括一个LN和MSA接残差、LN和FFN接残差(FNN有两层全连接)。解码器结构和常规transformer也差不多,但是多了一个特定任务标签嵌入,加入在decoder输入第一次MSA的Q和K上、第二次Q的上。尾部也是多个针对特定任务的结构,用于还原图像尺寸。
2.1.1 Heads
为了应对不同的图像处理任务,使用了一种多头部结构来分别处理每个任务,其中每个头部由三个卷积层组成。输入图像表示为
x
∈
R
3
×
H
×
W
x∈R^{3×H×W}
x∈R3×H×W,头部生成特征图
f
H
∈
R
C
×
H
×
W
f_H∈R^{C×H×W}
fH∈RC×H×W(通常
C
=
64
C=64
C=64)。计算公式为
f
H
=
H
i
(
x
)
f_H=H^i(x)
fH=Hi(x),其中
i
=
1
,
.
.
.
,
N
t
i=1, ..., N_t
i=1,...,Nt,
N
t
N_t
Nt表示任务数。
2.2.3 Transformer encoder
先将特征分割为patch,每个patch展开为一列向量。用p表示patch的大小,patch的数量为
N
=
H
W
p
2
N=\frac{HW}{p^2}
N=p2HW(序列长度),每个patch的长度为
p
2
×
C
p^2×C
p2×C。再加入位置信息,每个patch都添加了可学习的位置编码,与patch的数量、长度一一对应。
综上,输入编码器的为位置信息
E
p
i
E_{pi}
Epi和特征信息
f
p
i
f_{pi}
fpi。经线性变换生成Q、K、V后进行Multi-Head Self-Attention,再经过残差连接后送入全连接层。
2.2.4 Transformer decoder
在Decoder阶段,先将循环输入的当前序列进行一次Multi-Head Self-Attention操作,再取结果的Q和Encoder输出的K、V进行一次Multi-Head Self-Attention,再由一层全连接层得到最终输出序列。
2.2.5 Tailes
Tails 用于将特征映射到恢复的图像中。类似于Heads,多个Tails用于处理不同的任务。每个Tail的输出是具有与输入图像相同属性的结果图像,最终的的通道数重新被调整为3(RGB通道)。
2.2.6 损失函数
网络选择使用监督学习+对比学习的方式进行训练。整个IPT模型使用的损失函数为(λ用来平衡对比损失与监督损失):
L
I
P
T
=
L
s
u
p
e
r
v
i
s
e
d
λ
⋅
L
c
o
n
t
r
a
s
t
i
v
e
L_{IPT} = L_{supervised} λ · L_{contrastive}
LIPT=Lsupervisedλ⋅Lcontrastive
2.2.6.1 监督学习
用
I
c
o
r
r
u
p
t
e
d
I_{corrupted}
Icorrupted表示噪声图像,
I
c
o
r
r
u
p
t
e
d
=
f
(
I
c
l
e
a
n
)
I_{corrupted}= f(I_{clean})
Icorrupted=f(Iclean)损失值表示为:
损失函数是传统的L1损失。
2.2.6.2 对比学习
由于patch之间的信息很丰富,同一个feature map得到的patch更接近,引入了对比学习。对比学习用于最小化同一张输入图片之间patch的距离、最大化不同输入图片之间patch的距离。
其中,
D
i
1
D_{i1}
Di1和
D
i
2
D_{i2}
Di2表示不同的patch,
j
j
j和
k
k
k表示每个batch中不同的图片。
d
(
a
,
b
)
=
a
T
b
∣
∣
a
∣
∣
⋅
∣
∣
b
∣
∣
d(a,b)=\frac{a^Tb}{||a||·||b||}
d(a,b)=∣∣a∣∣⋅∣∣b∣∣aTb为余弦相似性。
参考来源:
【Pytorch】model.train() 和 model.eval() 原理与用法
深度学习——Batch Normalization算法原理和作用
Batch Normalization的原理和作用
什么是无偏估计?
无偏估计
均值和期望的关系
argparse.ArgumentParser()用法解析
action = “store_true
解析命令行参数的函数parse_known_args()与函数parse_args()的区别
Python Path.get_path方法代码示例
【Python笔记】之Python函数中参数前带*是什么意思
搞定PyTorch中模型保存和加载:torch.save()、torch.load()、torch.nn.Module.load_state_dict()
【PyTorch】torch.manual_seed() 详解
Python中动态导入对象importlib.import_module()的使用
PyTorch中模型的parameters()方法浅析
nn.ModuleList
Python 中下划线 _ 的用法
【Tensorflow】卷积神经网络中strides的参数
当卷积层后跟batch normalization层时为什么不要偏置b
详细解释卷积神经网络CNN中卷积层以及BN层的参数
PyTorch【5】torch.nn模块常用层与激活函数
多层感知器介绍
对PyTorch的dim的理解
通俗讲解pytorch中nn.Embedding原理及使用
nn.MultiheadAttention
Pytorch文档解读|torch.nn.MultiheadAttention的使用和参数解析
VIT Vision Transformer | 先从PyTorch代码了解
Pytorch 中的 torch.nn.Identity( ) 的作用
nn.LayerNorm()
在 Python 中,将方法名字前面加一个单下划线(_)
pytorch 笔记:torch.nn.init
深度学习参数初始化(二)Kaiming初始化 含代码
用Pytorch中定义一个自注意力层,查看模型参数发现这个层的可训练参数量为0,这正确吗或者为什么?
PyTorch中nn.Module的继承类中方法foward是自动执行的么?
【Torch API】pytorch 中register_buffer()函数详解
Torch.arange函数详解
Python中append和expend方法
浮点运算:双精度、单精度、半精度浮点数计算(FP16/FP32/FP64),浮点和定点
模型参数量化——mode.half()方法与AMP混合精度训练
Python中tqdm库的使用
参考MeanShift
如何理解卷积神经网络CNN的卷积核是四维向量
图像减均值除方差_图像预处理减去均值
数据预处理方式(去均值、归一化、PCA降维)
Python 中 & 是什么意思?
Python math.log() 方法
torch.nn.functional.unfold 用法解读
torch.nn.functional.unfold()
详解Python中的 transpose() 函数
python之transpose()函数
pytorch复习笔记–nn.Embedding()的用法
tensor.repeat()
torch.expand(-1, -1)的理解
【Pytorch】expand()用法==》扩展某个维度
通俗讲解pytorch中nn.Embedding原理及使用
【Torch API】pytorch 中register_buffer()函数详解
nn.LayerNorm的参数
注意力机制——Multi-Head Attention(MHA)
nn.MultiheadAttention详解 – forward()中维度、计算方式
pytorch笔记:nn.MultiheadAttention
多层感知机(MLP)
关于Contiguous()方法
Pytorch踩坑之transpose与view
nn.functional.fold/unfold
EDSR源码笔记
pytorch中tensor * tensor和tensor.mul(tensor)方法
python中lambda函数用法
Python简明讲解filter函数的用法
requires_grad,requires_grad_(),grad_fn区别
lr_scheduler设置
超分算法IPT:Pre-Trained Image Processing Transformer
python os.path.basename() 返回path最后的文件名
os.path.splitext
python中enumerate什么意思
Python中pickle模块用法详解
glob.glob()
HuggingFace学习笔记–Trainer的使用
python:__getitem__方法详解
神经网络中定义网络模型中的forward方法
nn.损失函数
Python中,hasattr()函数的详细介绍以及使用
torch.zeros() 函数介绍
torch.cat()函数的官方解释,详解以及例子
加载数据集的时候经常用到def getitem(self, index):具体怎么理解它呢?
enumerate(dataloader)的使用
pytorch之dataloader,enumerate
PyTorch中的 Dataset、DataLoader 和 enumerate()
Python使用os.path.basename和os.path.dirname获取文件名和路径名
os.path.splitext
ValueError: num_samples should be a positive integer value, but got num_samp=0
np.ascontiguousarray()详解
【torch.from_numpy函数的用法和逻辑】
pytorch中tensor.mul()和mm()和matmul()
图像处理之图像质量评价指标PSNR(峰值信噪比)
完整的模型测试(deom)步骤
optimizer.zero_grad()
权重衰减weight_decay参数从入门到精通
pytorch配置双显卡,使用双显卡跑代码
pytorch断点续训/恢复训练,以及绘制学习率曲线
torch.randn()函数
Python中进程模块的Process函数总结
快速学pytorch之评估模式:model.eval()
python torch.div( ) 函数使用说明
十分钟搞懂梯度消失与梯度爆炸
pytorch如何在预训练过的模型上继续训练?
Pytorch模型保存与加载,并在加载的模型基础上继续训练
pytorch 的 nelement()
Pytorch之clone(),detach(),new_tensor(),copy_()
【3D目标检测代码梳理】3.SECOND网络代码详解
tqdm参数
python-opencv第二期:imwrite函数详解
python的extend函数详解
说明:IPT知识部分有些链接没记录上,很抱歉没能标明出处,如果看到出处或者作者本人看到可以找我反馈我及时修改!