[论文解读-单图像去雨]Spatial Attentive Single-Image Deraining with a High Quality Real Rain Dataset
目录
1.相关链接
chrome-extension://aajldohlagodeegngemjjgmabejbejli/pdf/viewer.html?file=https://arxiv.org/pdf/1904.01538v2.pdf
https://codeload.github.com/stevewongv/SPANet/zip/master
2.发现问题
现有的合成雨模式单一缺乏多样性,缺少真实雨的数据集,之所以会产生缺乏真实雨数据的情况是因为难以同时获得真实的有雨和无雨的图像对。
3.贡献
- 提出了一种从真实有雨的连续帧中生成无雨清晰图像的方法,该方法结合了时间先验和人类监督(雨滴不可能长时间覆盖在同一个像素点上,视频中某一个像素点的值应该在真实值附近波动)。
- 构建了一个包含约29.5K的有雨和无雨图像对的大规模数据集。
- 提出了SPANet(Spatial Attentive Network),该方法从局部到全局消除雨滴。
4.从有雨视频得到无雨的清晰图像
上图中红色柱是人类选出的被雨滴覆盖的像素,蓝色是人类选出的没有被雨滴覆盖的像素,显然上图中该点的真实像素值应该在4附近。
我们假设每一帧图像包含
l
l
l 个像素点,假设选取
N
N
N 帧连续的有雨图像,我们可以定义每一帧上同一个像素点值的集合
O
l
=
{
o
1
l
,
.
.
.
,
o
N
l
}
O_l=\{o_{1l},...,o_{Nl}\}
Ol={o1l,...,oNl},根据上图我们可以很容易考虑到选择集合中出现次数最多的像素值(此处不是用RGB值计算出现次数,而是用亮度值,作者是把图像转换到YCbCr空间得到亮度通道后再进行统计)作为背景的估计值会比较合理,也就是计算集合
O
l
O_l
Ol 的众数,我们可以通过以下等式来表达:
ϕ l = Φ ( O l ) , (1) \phi_l=\Phi(O_l),\tag 1 ϕl=Φ(Ol),(1)
其中
Φ
\Phi
Φ 是mode函数,返回集合中出现次数最多的数值。但在式(1)中计算
ϕ
l
\phi_l
ϕl 时并没有考虑邻域信息,当图像中包含稠密雨时,计算得到的图像中会包含许多噪声。所以进一步计算
O
l
O_l
Ol 中的
ϕ
l
\phi_l
ϕl 对应的百分位数范围
(
R
l
m
i
n
,
R
l
m
a
x
)
(R_l^{min},R_l^{max})
(Rlmin,Rlmax) ,下图中(b)直观展示了此处百分比范围的含义。
R
l
min
=
100
%
N
∑
i
=
1
N
{
1
∣
o
i
l
<
ϕ
l
}
R
l
max
=
100
%
N
∑
i
=
1
N
{
1
∣
o
i
l
>
ϕ
l
}
(2)
\begin{aligned}&R_{l}^{\min }=\frac{100 \%}{N} \sum_{i=1}^{N}\left\{1 \mid o_{i l}<\phi_{l}\right\}\\&R_{l}^{\max }=\frac{100 \%}{N} \sum_{i=1}^{N}\left\{1 \mid o_{i l}>\phi_{l}\right\}\end{aligned}\tag 2
Rlmin=N100%i=1∑N{1∣oil<ϕl}Rlmax=N100%i=1∑N{1∣oil>ϕl}(2)
通过以上方法,我们可以得到连续帧中每一个像素点的亮度值的众数的百分比范围(换句话说,就是得到了每个像素点对应的
ϕ
\phi
ϕ 的百分比范围),我们将每个百分比范围作在图上就能得到上图(a)中的多个蓝色条,再取一根线让它穿过尽量多的蓝色条,也就是图中红色虚线,这就是我们得到的百分比值,然后我们用这个值对每一个像素点在时间(帧)维度上做百分比滤波(例如中值滤波就是百分比值为50%的百分比滤波,每个像素点用同样的值做滤波)。
不难想到,这种方法的效果与
N
N
N 的取值有关,作者通过简单地多次试验,得到结论:稀疏或者正常雨时
N
N
N 取20-100,稠密雨时
N
N
N 取200-300。
5.网络结构
5.1.主干
将带雨滴的图片送入上图中(a)所示结构,会先通过一个卷积层(个人认为这里的卷积主要是为了调整通道数,另外也能一定程度上增大感受野),之后是三个标准的残差块,用于特征提取。接着是四个SAB,如上图中(b)所示,SAB中包含三个SARB和一个SAM。SAM生成的attention map用来指导SARB去除雨的条纹。要注意的的是,四个SAB中的SAM的权重是共享的,SARB的权重不共享。紧接着是两个串联的标准残差块和一个卷积层来重构清晰背景图。下面会放上主干部分及SAM部分的代码帮助理解。
5.2.SAM
feature map送入SAM中后,首先是进入注意力模块计算注意力权重(上图中d最上边的分支,输出有四份独立的attention map,attention map在四个分支上没有共享,在两个阶段上是共享的)。在主分支上,首先是对输入的feature map做3x3卷积,然后送入四个方向的IRNN,也就是图中带箭头的小方框。然后将IRNN的输出与注意力权重相乘,之后将四个方向的输出按通道维度相拼接,然后就重复之前从3x3卷积开始的过程,两个阶段结构相似但权重不共享。最后是一个relu激活的卷积层与一个sigmoid激活的卷积层,最后的输出作为主干中用到的attention map。以向右移动的IRNN为例,其计算方法见式(3)(至于为什么是减一不是加一,个人认为是因为在图像处理中,一般是以最左上角像素点为坐标原点,并且向右移动的IRNN处理是从左到右的)。
h
i
,
j
←
max
(
α
dir
h
i
,
j
−
1
+
h
i
,
j
,
0
)
(3)
h_{i, j} \leftarrow \max \left(\alpha_{\text {dir}} h_{i, j-1}+h_{i, j}, 0\right)\tag 3
hi,j←max(αdirhi,j−1+hi,j,0)(3)
计算右移的IRNN时是将输入中的每一行看成一个输入,一行中的每一列作为不同时刻的输入。feature map送入IRNN之后先用一个3x3卷积作为输入到影藏层的权重(类似RNN中的U),这一步处理就能得到
h
h
h ,然后再对
h
h
h 应用式(3)就能得到四个方向的输出。
用通俗的话来解释式(3):就是对当前的
h
i
,
j
h_{i, j}
hi,j 加上上一个
h
h
h 即
h
i
,
j
−
1
h_{i, j-1}
hi,j−1(因为向右)的一部分。式(3)中
α
dir
\alpha_{\text {dir}}
αdir 类似RNN中的W。
IRNN细节可以参考论文《Inside-outside net: Detecting objects in context withskip pooling and recurrent neural networks》。
class SAM(nn.Module):
def __init__(self,in_channels,out_channels,attention=1):
super(SAM,self).__init__()
self.out_channels = out_channels
self.irnn1 = Spacial_IRNN(self.out_channels)
self.irnn2 = Spacial_IRNN(self.out_channels)
self.conv_in = conv3x3(in_channels,in_channels)
self.conv2 = conv3x3(in_channels*4,in_channels)
self.conv3 = conv3x3(in_channels*4,in_channels)
self.relu2 = nn.ReLU(True)
self.attention = attention
if self.attention:
self.attention_layer = Attention(in_channels)
self.conv_out = conv1x1(self.out_channels,1)
self.sigmod = nn.Sigmoid()
def forward(self,x):
if self.attention:
weight = self.attention_layer(x)
out = self.conv_in(x)
top_up,top_right,top_down,top_left = self.irnn1(out)
# direction attention
if self.attention:
top_up.mul(weight[:,0:1,:,:])
top_right.mul(weight[:,1:2,:,:])
top_down.mul(weight[:,2:3,:,:])
top_left.mul(weight[:,3:4,:,:])
out = torch.cat([top_up,top_right,top_down,top_left],dim=1)
out = self.conv2(out)
top_up,top_right,top_down,top_left = self.irnn2(out)
# direction attention
if self.attention:
top_up.mul(weight[:,0:1,:,:])
top_right.mul(weight[:,1:2,:,:])
top_down.mul(weight[:,2:3,:,:])
top_left.mul(weight[:,3:4,:,:])
out = torch.cat([top_up,top_right,top_down,top_left],dim=1)
out = self.conv3(out)
out = self.relu2(out)
mask = self.sigmod(self.conv_out(out))
return mask
class SPANet(nn.Module):
def __init__(self):
super(SPANet,self).__init__()
self.conv_in = nn.Sequential(
conv3x3(3,32),
nn.ReLU(True)
)
self.SAM1 = SAM(32,32,1)
self.res_block1 = Bottleneck(32,32)
self.res_block2 = Bottleneck(32,32)
self.res_block3 = Bottleneck(32,32)
self.res_block4 = Bottleneck(32,32)
self.res_block5 = Bottleneck(32,32)
self.res_block6 = Bottleneck(32,32)
self.res_block7 = Bottleneck(32,32)
self.res_block8 = Bottleneck(32,32)
self.res_block9 = Bottleneck(32,32)
self.res_block10 = Bottleneck(32,32)
self.res_block11 = Bottleneck(32,32)
self.res_block12 = Bottleneck(32,32)
self.res_block13 = Bottleneck(32,32)
self.res_block14 = Bottleneck(32,32)
self.res_block15 = Bottleneck(32,32)
self.res_block16 = Bottleneck(32,32)
self.res_block17 = Bottleneck(32,32)
self.conv_out = nn.Sequential(
conv3x3(32,3)
)
def forward(self, x):
out = self.conv_in(x)
out = F.relu(self.res_block1(out) + out)
out = F.relu(self.res_block2(out) + out)
out = F.relu(self.res_block3(out) + out)
Attention1 = self.SAM1(out)
out = F.relu(self.res_block4(out) * Attention1 + out)
out = F.relu(self.res_block5(out) * Attention1 + out)
out = F.relu(self.res_block6(out) * Attention1 + out)
Attention2 = self.SAM1(out)
out = F.relu(self.res_block7(out) * Attention2 + out)
out = F.relu(self.res_block8(out) * Attention2 + out)
out = F.relu(self.res_block9(out) * Attention2 + out)
Attention3 = self.SAM1(out)
out = F.relu(self.res_block10(out) * Attention3 + out)
out = F.relu(self.res_block11(out) * Attention3 + out)
out = F.relu(self.res_block12(out) * Attention3 + out)
Attention4 = self.SAM1(out)
out = F.relu(self.res_block13(out) * Attention4 + out)
out = F.relu(self.res_block14(out) * Attention4 + out)
out = F.relu(self.res_block15(out) * Attention4 + out)
out = F.relu(self.res_block16(out) + out)
out = F.relu(self.res_block17(out) + out)
out = self.conv_out(out)
return Attention1 , out
6.训练细节
6.1.损失函数
L total = L 1 + L S S I M + L A t t (4) \mathcal{L}_{\text {total}}=\mathcal{L}_{1}+\mathcal{L}_{S S I M}+\mathcal{L}_{A t t}\tag 4 Ltotal=L1+LSSIM+LAtt(4)
其中
L
1
\mathcal{L}_{\text {1}}
L1表示重构误差。
L
S
S
I
M
=
1
−
SSIM
(
P
,
C
)
\mathcal{L}_{S S I M}=1-\operatorname{SSIM}(\mathcal{P}, \mathcal{C})
LSSIM=1−SSIM(P,C),用于约束结构相似性,其中
P
\mathcal{P}
P 是预测结果,
C
\mathcal{C}
C 是清晰图像。
L
a
t
t
=
∥
A
−
M
∥
2
2
\mathcal{L}_{a t t}=\|\mathcal{A}-\mathcal{M}\|_{2}^{2}
Latt=∥A−M∥22,
A
\mathcal{A}
A 是来自第一个SAM的attention map,
M
\mathcal{M}
M 是一个指示雨条纹位置的binary map(因为有无雨图像和有雨图像对,相减再二值化即可得到binary map),binary map中用1表示被雨覆盖的像素点,用0表示没被覆盖的像素点。
个人认为从本质上来说获取attention map是一个分类问题,使用BCE来计算
L
a
t
t
\mathcal{L}_{a t t}
Latt 会不会显得更合理呢?另外,因为结合了三项指标共同构成损失函数,但我们无法量化各项指标的重要程度,所以我认为可以引入对应的权重作为超参数来权衡各项指标的重要程度。
6.2.其他细节
硬件配置:
- E5-2640 V4
- 4个NVIDIA Titan V
优化器:
- Adam
batchsize:
- 16
数据扩增:
- 随机放缩
- 随机裁剪
学习率:
- 以0.005开始
- 30k个iterations后变0.0005
- 共40k个iterations