该文是香港中文、腾讯优图以及旷视联合发表于CVPR2018的关于图像去模糊的论文。
由于拍照过程中的抖动、焦点等问题,相机拍摄的图像往往存在模糊现象,这种模糊移除是计算机视觉领域很重要的研究课题之一。
在单帧图像去模糊问题中,由“粗到精”的方案无论是在传统的优化方法还是在深度学习方法中均取得了成功。在该文中,基于多尺度策略,作者提出一种尺度训练网络用于去模糊。相比已有的方法,SRN具有更为简单的网络架构,参数量更少,更易于训练。
在具有复杂运动的大规模模糊数据及上对所提方法进行评估,结果表明:在定量与定性方面,该方法可以取得更高质量的结果。
Abstract
作者探索了一种更有效的架构用于多尺度图像去模糊,所提出的尺度循环网络解决了CNN去模糊方法中的两个重要而常见的问题。论文的创新点主要如下:
- Scale-recurrent Structure。作者提出跨尺度权值共享方案以降低训练难度。这种处理方式的号处理有:(1) 降低参数量;(2) 集成递归模块(这些模块可以捕获跨尺度的有助于复原任务的信息)。
- Encoder-Decoder ResBlock Network。直接应用编解码架构并不能产生最优结果。作者所提架构可以提升不同CNN架构的优点,同时具有非常大的感受野(这对于大尺度运动去模糊极为重要)。
Method
上图给出了作者所提出的尺度递归去模糊网络架构。它以一系列由输入图像按照不同尺度下采样得到的模糊图像作为输入,输出相应的清晰图像,其中全分辨率输出为最终的输出。
Scale-recurrent Network
作者采用了一种新颖的“自粗而精”的跨尺度递归架构。它将不同尺度的清晰图像生成过程视为图像去模糊的子问题。它以模糊图像以及初始去模糊结果作为输入,估计当前尺度下的清晰图像。该问题可以描述为:
其中,i表示尺度索引(i=1表示最精细的尺度)。
上述公式简单的对所提方法进行了定义,实际上,网络架构比较灵活,有很多中可供选择。首先,递归网络有多种选择,如RNN,LSTM以及GRU等等,作者选择了ConvLSTM(因其具有更好的结果)。其次,上采样操作也有多种选择,如转置卷积、PixelShuffle以及插值方法;最后,网络需要进行合理配置以处理不同尺度下的复原。
Encoder-Decoder with ResBlock
编解码指的是一种对称的CNN架构,首先逐渐的将输入逐渐转换为更小分辨率更多通道的特征,然后逐渐转换到原始输入大小的输出。同分辨率下的编码与解码之间可以通过跳过连接方式进行特征组合,跳过连接有助于梯度传播,加速收敛。编解码架构已在多种视觉任务中证实了其有效性。然而,直接将其用到去模糊任务中并非最优选择,作者给出的解释如下。
对于去模糊任务而言,感受野应当足够大以保证可以处理严重的大尺度运动。
堆叠更多的编码/解码模块会造成中间特征的分辨率过小不利于重建;堆叠卷积模块可以提升感受野,但同时会噪声过高的参数量与计算量。
基于上述分析,作者对编解码架构进行了几个改进以适配其架构。改进包含以下几个方面:
- 通过引入残差模块(无BN层)改进编码/解码模块,其编码模块包含一个stride=2的卷积,外加几个残差模块,分辨率降低同时通道数加倍;解码模块与编码模块堆成,即包含几个残差模块,外加一个转置卷积,上采样的同时通道率减半。
- 在网络内部设置递归模块。在隐状态模块中插入卷积层以连接连续尺度特征,这里卷积层的尺度为
。
修正后的网络描述为:
Loss
在训练过程中,对每个尺度采用欧式损失。整体损失定义为:
作者还提到:曾尝试过对抗损失以及TV损失,但上述损失已经具有足够好的结果。
Experiments
在实验过程中,选用GoPro数据(它包含3214对数据)进行进行训练。选用Adam(0.9, 0.999)优化器,学习率以指数方式从0.0001经由2000epoch以0.3指数下降到1e-6.大概花费72小时时间。训练过程中输入块大小为256x256,BatchSize设置为16。同时对ConvLSTM模块中的权值进行梯度裁剪(全局范数3)以稳定训练。在测试过程中,对720p大小的图像,其推理耗时为1.87s。
下面给出了在基准数据集上所提方法与SOTA方法的性能对比。
Conclusion
作者解释了什么架构适合于“自粗而精”的图像去模糊机制,并提出一种尺度递归架构。相比已有多尺度去模糊方法,该架构具有更少的参数、更易训练。在定性与定量方面,该方法均取得了SOTA性能。作者认为:这种架构同样可以应用其它图像处理任务中。
参考代码
import torch
import torch.nn as nn
import torch.nn.functional as F
# resnet-block
class ResBlock(nn.Module):
def __init__(self, inc, ksize):
super(ResBlock, self).__init__()
pad = (ksize - 1)//2
self.net = nn.Sequential(nn.Conv2d(inc,inc,ksize,1,pad),
nn.ReLU(),
nn.Conv2d(inc,inc,ksize,1,pad))
def forward(self, x):
return x + self.net(x)
# create as the author's code
class ConvLSTM(nn.Module):
def __init__(self):
super(ConvLSTM, self).__init__()
self.conv = nn.Conv2d(256, 512, 3, 1, 1)
def forward(self, x, h, c):
res = self.conv(torch.cat([x, h], 1))
i, j, f, o = res.split(128, 1)
i = i.sigmoid()
j = j.tanh()
f = torch.sigmoid(f + 1.0)
o = o.sigmoid()
new_c = c * f + i * j
new_h = new_c.tanh() * o
return new_h, new_c
def init_hidden_state(self, N, C, H, W):
return (torch.zeros(N,C,H,W),torch.zeros(N,C,H,W))
# create as the author's code
class LstmNet(nn.Module):
def __init__(self):
super(LstmNet, self).__init__()
eblk1 = [nn.Conv2d(2, 32, 5, 1, 2), nn.ReLU(),
ResBlock(32, 5),
ResBlock(32, 5),
ResBlock(32, 5)]
eblk2 = [nn.Conv2d(32, 64, 5, 2, 2), nn.ReLU(),
ResBlock(64, 5),
ResBlock(64, 5),
ResBlock(64, 5)]
eblk3 = [nn.Conv2d(64, 128, 5, 2, 2), nn.ReLU(),
ResBlock(128, 5),
ResBlock(128, 5),
ResBlock(128, 5)]
self.encoder1 = nn.Sequential(*eblk1)
self.encoder2 = nn.Sequential(*eblk2)
self.encoder3 = nn.Sequential(*eblk3)
dblk3 = [ResBlock(128, 5),
ResBlock(128, 5),
ResBlock(128, 5),
nn.ConvTranspose2d(128, 64, 4, 2, 1),
nn.ReLU()]
dblk2 = [ResBlock(64, 5),
ResBlock(64, 5),
ResBlock(64, 5),
nn.ConvTranspose2d(64, 32, 4, 2, 1),
nn.ReLU()]
dblk1 = [ResBlock(32, 5),
ResBlock(32, 5),
ResBlock(32, 5),
nn.Conv2d(32, 1, 5, 1, 2)]
self.decoder1 = nn.Sequential(*dblk1)
self.decoder2 = nn.Sequential(*dblk2)
self.decoder3 = nn.Sequential(*dblk3)
self.convlstm = ConvLSTM()
def _step_(self, x, h, c):
e32 = self.encoder1(x)
e64 = self.encoder2(e32)
e128 = self.encoder3(e64)
h, c = self.convlstm(e128, h, c)
d64 = self.decoder3(h)
d32 = self.decoder2(d64 + e64)
d3 = self.decoder1(d32 + e32)
return d3, h, c
def forward(self, x1, x2, x4):
N, _, H, W = x4.size()
h, c = self.convlstm.init_hidden_state(N,128,H//4,W//4)
i4, h, c = self._step_(torch.cat([x4, x4], 1), h, c)
output = [i4]
h = F.interpolate(h, scale_factor=2, mode='bilinear')
c = F.interpolate(c, scale_factor=2, mode='bilinear')
i4 = F.interpolate(i4, scale_factor=2, mode='bilinear')
i2, h, c = self._step_(torch.cat([x2, i4], 1), h, c)
output.append(i2)
h = F.interpolate(h, scale_factor=2, mode='bilinear')
c = F.interpolate(c, scale_factor=2, mode='bilinear')
i2 = F.interpolate(i2, scale_factor=2, mode='bilinear')
i1, h, c = self._step_(torch.cat([x1, i2], 1), h, c)
output.append(i1)
return output
def demo():
x1 = torch.randn(4, 3, 128, 128)
x2 = torch.randn(4, 3, 64, 64)
x4 = torch.randn(4, 3, 32, 32)
model = Net().eval()
with torch.no_grad():
output = model(x1, x2, x4)
print('1/4: {}'.format(output[0].size()))
print('1/2: {}'.format(output[1].size()))
print('1/1: {}'.format(output[2].size()))
if __name__ == "__main__":
demo()
作者一共提供了三组模型,分别是输入彩色无LSTM模块,输入灰度图像无LSTM模块,输入灰度图像有LSTM模块,但没有输入彩色有LSTM模块的模型。下图为本人将原作者tensorflow模型转为pytorch后的LSTMNet去模糊效果图,可以确认该模型是没有问题的。该方法确实能提供不错的去模糊效果,但算法耗时有点多,基本不用考虑在手机上的应用了。
后记
20190816:经查,因某些未知原因,尽管tensorflow与pytorch模型参数完全相同,但是会因为resize、stride等原因导致最终输出的结果存在差异,导出模型效果比原始tf模型效果要差一些。初步猜测与框架底层处理机制有关。差异有以下两点:
(1) resize结果不一致,tf中的resize参数是非对齐模型,但与pytorch中的对齐模式类似,而opencv中无该对齐参数;这种差异导致了结果上的差异;
(2) 将tf的输入导出送入到pytorch模型中,encoder1的结果完全一致,但经encoder2中的stride=2的卷积后结果则出现了差异。经查,两个卷积的参数则是完全一致的。为什么导致了这种现象目前未知。如有小伙伴碰到过类似现象并解决者,烦请不吝赐教。
注:作者在encoder2中的stride卷积代码:
slim.conv2d(conv1_4, 64, [5, 5], stride=2, scope='enc2_1'),
它包含卷积+relu操作,而在pytorch中将其写成了两个。个人认为应该没有差异才对。
20190830:经过在多个模型上测试发现:tensorflow的stride=2卷积计算方式与pytorch的不一致。pytorch的计算方式更符合本人的认识,本人不确定是不是tf的版本导致的。但在1.17.0版本上,测试多个tf模型均存在这个问题。看来以后得彻底放弃tf了。
只是简单的判断了以上两点不同,后续的没有进行检查。但整个网络架构应该是没有问题的,基于该架构在GoPro数据集上进行finetune应该能达到匹配作者的tf模型的效果。
在github上发现有小伙伴对该网络进行了复现,但效果比作者的稍差。查看代码发现:其ConvLSTM模块与作者的不太一样,但影响应该不是很大。
20190905补充:
转换模型的差异原因已经找到:tensorflow与pytorch在stride>1时的padding方式会有区别。pytorch中的padding方式比较单纯(padding与stride无关);而tensorflow中padding与stride存在关联性。tensorflow中的padding方式见如下链接,感兴趣者可以参考该链接更新一下上述stride=2中的卷积部分即可。
https://www.tensorflow.org/versions/r1.12/api_guides/python/nn#Convolutionwww.tensorflow.org