目录
论文发表于CVPR2019。论文 github
去雨深度模型越来越复杂多样,难以分析不同网络模块的作用。该论文则从网络结构、输入输出、损失函数这几个方面提出更好且更简单的baseline。
模型
不考虑建立更深、更复杂的模型,作者选择分多个阶段(multi-stage)解决问题,在每个阶段部署一个浅层ResNet
考虑到堆叠网络会大量增加参数且造成过拟合,作者利用阶段之间的递归计算(inter-stage recursive computation)使多个阶段共享网络参数,还提出了使用进一步减少参数的阶段内部的递归计算(intra-stage recursive computation)
输入输出
输入:每个阶段的结果和原始雨图的拼接作为每个ResNet的输入
输出:残差图像。直接用模型学习雨图中的干净背景也是可行的
网络结构
以有5个ResBlock的浅层残差网络为基础,作者建立以下几种模型:
PRN:Progressive Residual Network
将一个ResNet重复在
T
T
T个阶段上展开,网络参数在不同阶段重复使用
每个阶段的网络具体包含以下部分:
- f i n f_{in} fin:Conv+ReLU,接受上个阶段输出的图像和原始雨图的拼接作为输入
- f r e s f_{res} fres:5个ResBlock,提取深度特征表示
-
f
o
u
t
f_{out}
fout:Conv,输出去雨结果
每个阶段 T T T的推断过程用以下公式描述:
源代码:
class PRN(nn.Module):
def __init__(self, recurrent_iter=6, use_GPU=True):
super(PRN, self).__init__()
self.iteration = recurrent_iter
self.use_GPU = use_GPU
# f_in
self.conv0 = nn.Sequential(
nn.Conv2d(6, 32, 3, 1, 1),
nn.ReLU()
)
# f_res, 五个ResBlock
self.res_conv1 = nn.Sequential(
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU()
)
self.res_conv2 = nn.Sequential(
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU()
)
self.res_conv3 = nn.Sequential(
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU()
)
self.res_conv4 = nn.Sequential(
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU()
)
self.res_conv5 = nn.Sequential(
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU()
)
# f_out
self.conv = nn.Sequential(
nn.Conv2d(32, 3, 3, 1, 1),
)
def forward(self, input):
x = input
x_list = []
for i in range(self.iteration):
x = torch.cat((input, x), 1)
x = self.conv0(x)
resx = x
x = F.relu(self.res_conv1(x) + resx) # 残差连接 下亦如是
resx = x
x = F.relu(self.res_conv2(x) + resx)
resx = x
x = F.relu(self.res_conv3(x) + resx)
resx = x
x = F.relu(self.res_conv4(x) + resx)
resx = x
x = F.relu(self.res_conv5(x) + resx)
x = self.conv(x)
x = x + input # 残差学习
x_list.append(x)
return x, x_list
PReNet: Progressive Recurrent Network
在PRN的基础上加入一个循环(recurrent)层挖掘不同阶段之间的深层特征,即下图中的蓝绿色模块
每个阶段与PRN唯一的不同之处就是循环层的引入。
f
r
e
c
u
r
r
e
n
t
f_{recurrent}
frecurrent以该阶段
f
i
n
f_{in}
fin的输出和上一个阶段的循环层的状态为输入,可以用LSTM或GRU实现。作者指出LSTM会更好
每个阶段
T
T
T的推断过程用以下公式描述:
源代码:
class PReNet(nn.Module):
def __init__(self, recurrent_iter=6, use_GPU=True):
super(PReNet, self).__init__()
self.iteration = recurrent_iter
self.use_GPU = use_GPU
# f_in
self.conv0 = nn.Sequential(
nn.Conv2d(6, 32, 3, 1, 1),
nn.ReLU()
)
# f_res
self.res_conv1 = nn.Sequential(
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU()
)
self.res_conv2 = nn.Sequential(
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU()
)
self.res_conv3 = nn.Sequential(
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU()
)
self.res_conv4 = nn.Sequential(
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU()
)
self.res_conv5 = nn.Sequential(
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU()
)
# f_recurrent(LSTM)
self.conv_i = nn.Sequential(
nn.Conv2d(32 + 32, 32, 3, 1, 1),
nn.Sigmoid()
)
self.conv_f = nn.Sequential(
nn.Conv2d(32 + 32, 32, 3, 1, 1),
nn.Sigmoid()
)
self.conv_g = nn.Sequential(
nn.Conv2d(32 + 32, 32, 3, 1, 1),
nn.Tanh()
)
self.conv_o = nn.Sequential(
nn.Conv2d(32 + 32, 32, 3, 1, 1),
nn.Sigmoid()
)
# f_out
self.conv = nn.Sequential(
nn.Conv2d(32, 3, 3, 1, 1),
)
def forward(self, input):
batch_size, row, col = input.size(0), input.size(2), input.size(3)
x = input
h = Variable(torch.zeros(batch_size, 32, row, col))
c = Variable(torch.zeros(batch_size, 32, row, col))
if self.use_GPU:
h = h.cuda()
c = c.cuda()
x_list = []
for i in range(self.iteration):
x = torch.cat((input, x), 1) # 拼接原始输入和上一阶段的输出作为输入
x = self.conv0(x)
x = torch.cat((x, h), 1) # 拼接f_in的结果和recurrent层的状态
i = self.conv_i(x)
f = self.conv_f(x)
g = self.conv_g(x)
o = self.conv_o(x)
c = f * c + i * g
h = o * torch.tanh(c)
x = h
resx = x
x = F.relu(self.res_conv1(x) + resx)
resx = x
x = F.relu(self.res_conv2(x) + resx)
resx = x
x = F.relu(self.res_conv3(x) + resx)
resx = x
x = F.relu(self.res_conv4(x) + resx)
resx = x
x = F.relu(self.res_conv5(x) + resx)
x = self.conv(x)
x = x + input # 残差学习。作者测试LSTM和GRU性能时用的模型PReNet_LSTM则是采用直接映射
x_list.append(x)
torch.cuda.empty_cache()
return x, x_list
PRNt r _{r} r & PReNet r _{r} r
利用阶段内部(intra-stage)的递归计算,重复使用同一个ResBlock(下图中右边的Recursive ResBlocks),减少网络参数的同时保持SOTA水平,在模型大小和去雨性能之间做了折中。
源代码:
class PReNet_r(nn.Module):
def __init__(self, recurrent_iter=6, use_GPU=True):
super(PReNet_r, self).__init__()
self.iteration = recurrent_iter
self.use_GPU = use_GPU
# f_in
self.conv0 = nn.Sequential(
nn.Conv2d(6, 32, 3, 1, 1),
nn.ReLU()
)
# f_res,只有一个ResBlock
self.res_conv1 = nn.Sequential(
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU()
)
# f_recurrent
self.conv_i = nn.Sequential(
nn.Conv2d(32 + 32, 32, 3, 1, 1),
nn.Sigmoid()
)
self.conv_f = nn.Sequential(
nn.Conv2d(32 + 32, 32, 3, 1, 1),
nn.Sigmoid()
)
self.conv_g = nn.Sequential(
nn.Conv2d(32 + 32, 32, 3, 1, 1),
nn.Tanh()
)
self.conv_o = nn.Sequential(
nn.Conv2d(32 + 32, 32, 3, 1, 1),
nn.Sigmoid()
)
# f_out
self.conv = nn.Sequential(
nn.Conv2d(32, 3, 3, 1, 1),
)
def forward(self, input):
batch_size, row, col = input.size(0), input.size(2), input.size(3)
#mask = Variable(torch.ones(batch_size, 3, row, col)).cuda()
x = input
h = Variable(torch.zeros(batch_size, 32, row, col))
c = Variable(torch.zeros(batch_size, 32, row, col))
if self.use_GPU:
h = h.cuda()
c = c.cuda()
x_list = []
for i in range(self.iteration):
x = torch.cat((input, x), 1) # 拼接原始输入和上一阶段的输出作为输入
x = self.conv0(x)
x = torch.cat((x, h), 1) # 拼接f_in的结果和recurrent层的状态
i = self.conv_i(x)
f = self.conv_f(x)
g = self.conv_g(x)
o = self.conv_o(x)
c = f * c + i * g
h = o * torch.tanh(c)
x = h
for j in range(5):
resx = x
x = F.relu(self.res_conv1(x) + resx)
x = self.conv(x)
x = input + x
x_list.append(x)
return x, x_list
网络细节
- 阶段
T
T
T选取为6, 6个阶段(6-stage)的PReNet在第一个阶段即可去掉大部分雨水,剩下的雨线逐步被移除。
- 所有的Conv层的filter大小为3x3,padding为1
- f i n f_{in} fin输入channel为6,输出channel为32; f r e c u r r e n t f_{recurrent} frecurrent和 f r e s f_{res} freschannel为32; f o u t f_{out} fout输入channel为32,输出channel为3
- f r e s f_{res} fres由五个ResBlock组成
损失函数
近期去雨模型中很多都使用了混合损失函数(如MSE+SSIM)和对抗损失。该论文作者指出,这些损失增加了调整超参的负担。由于渐进式网络结构的存在,单独的MSE或者负SSIM已经足够训练PRN和PReNet达到理想效果。
对最后一个阶段的输出
x
T
x^{T}
xT进行监督学习,
x
g
t
x^{gt}
xgt是对应的ground truth无雨图像。
MSE损失函数如下:
负SSIM损失函数如下:
还可以对每个中间结果进行递归监督。
λ
t
\lambda_{t}
λt是阶段
t
t
t的折中(tradeoff)参数。递归监督在t=T时无法实现性能提升,但是可以在早期阶段生成视觉上较为满意的结果(见实验Ablation study的损失函数部分)
实验
数据增广
以步长为80,将图像裁剪成100x100的块(patch)进行数据增广。以Rain100L数据集为例,训练图像从200张增加到6000张
Ablation Study
使用Rain100H数据集中的1254张训练图像(原1800张,去除了背景重复的图像)和100张测试图像进行ablation study
损失函数
负SSIM v.s. MSE
训练两个PReNET模型,分别以最小化MSE损失和负SSIM损失为目标,在Rain100H上的SSIM和PSNR结果显示,以负SSIM为损失函数的模型在两个指标上都表现更好
单次 v.s. 循环使用
损失函数可以只应用于最后一个阶段或者递归地在各个阶段使用。对于后者,将tradeoff参数设置为最后一个阶段大于其他阶段。实验结果显示,这两种方式视觉上不相上下,评分上后者略逊于前者。这说明单独的损失足够训练该论文的渐进式网络。
另外,下图中,后一种方式训练的模型的中间阶段的SSIM和PSNR评分高于前一种方式,也就是说在计算资源有限时,使用后一种方式训练,在任何一个阶段停下来,都能获得较好的结果
网络结构 & 输入输出
PReNet x _{x} x, PReNet-LSTM, PReNet-GRU学习雨图到干净背景图的直接映射,PReNet则为采用残差映射,引入LSTM层,将每个阶段的结果和原始雨图的拼接作为输入的最终模型
对于输入,PReNet
x
_{x}
x比PReNet在PSNR和SSIM评分都要低,说明在每个阶段接受原始雨图作为输入是有好处的;对于输出,从实验结果来看,直接映射也能取得理想的效果
表三
表四
对于模型结构,从表四PRN和PReNET的对比结果来看,循环层的引入对去雨效果有提升;从表三PReNet-LSTM和PReNet-GRU的评分来看,LSTM比GRU更好;表四显示,重复使用一个ResBlock的PRN r _{r} r和PReNet r _{r} r比PRN和PReNet效果稍差,但是也达到了较为不错的水平
阶段个数
太大的阶段个数T会使得模型难以训练,通过实验,作者选择T=6