今天要介绍的论文是一篇2017年发表于CVPR的论文《Learning from Simulated and Unsupervised Images through Adversarial Training》,第一作者为:Ashish Shrivastava,马里兰大学计算机视觉博士,原文链接为:
http://www.im2m.com.cn/uploads/5/file/public/201712/20171213100812_02i7mq1r98.pdf,如图1所示,它讲述的是一种基于无监督学习的方式,使用未标记的真实数据,通过对抗性训练改善合成图像,使合成图像更具真实性的方法。
图1 Simulated+Unsupervised (S+U) learning
背景
图2 生成对抗网络GAN
论文中提出的“模拟+无监督学习”(Simulated + Unsupervised Learning)的学习方法,使用的也是GANs的逻辑,我们先来简单回顾一下GANs模型。
如图2所示,GAN模型简单来说就是让两个网络相互博弈,玩一个“猫鼠游戏”。
生成器网络G根据输入向量生成假图片。
鉴别器网络D同时观察真实和假造的数据,判断这个图片到底是不是真的数据。
G尝试用自己的生成的假图片来“蒙骗”D,而D也不断提高自己鉴别真伪的水平。这样G的图像合成能力和D的鉴别能力都会越来越高超。
1
如图3所示,本文的“对抗训练”是对原有的GANs模型做了稍加的修改,形成了自己的SimGANs模型。其中Sim是单词 Simulator的缩写,即模拟器的意思。SimGANs主要包括了三部分:模拟器S(Simulator)和改善器R(Refiner),以及鉴别器D(Discriminator)。模拟器生成合成图像,改善器优化合成图像使其更像真实图片,最后再由鉴别器做识别训练。
类似于GANs网络,模型的训练主要包含了两个核心模块,改善器R和鉴别器D。改善器输入合成数据,输出改善结果。鉴别器则是判断输入的图像是真实数据还是经过改善的合成数据。
图3 SimGANs模型
其实这里的改善器R就相当于GANs中的生成器G。在R与D两者的相互博弈中,我们期望经过若干次迭代后得到的改善器R,可以让模拟器所生成合成图像在外观上更像一个真实的图像,同时保存从模拟器中得到的注释信息。为此,本文定义了如下的目标损失函数。
Loss Function
整个模型的训练总体上可以分为鉴别器网络的训练和改善网络的训练。
如公式1所示,鉴别器网络训练的目标损失函数跟传统GANs相比并没有太大的差异,x˜是改善器R所改善的图像,y是未标记的真实数据。在该损失函数中,第一部分是希望改善后的图片尽可能真实,第二部分希望未标注真实图片输入后的值尽可能小。整个网络训练的目的是要让鉴别器尽可能区分出真实图像与合成图像。
公式1 鉴别器网络D目标损失函数
而对于改善网络呢,我们可以用公式2来表示,可以理解为一个函数,其中θ是函数参数,x为输入合成图像,输出改善图像x˜。
公式2 R模型
如公式3所示,改善网络的目标损失函数由两部分组成,第一部分real部分的展开如下,它主要是增加了合成图像的真实感,该部分损失的计算需要用到训练过程中鉴别器网络的分类结果,利用已经训练过的鉴别器网络来更新改善网络;第二部分reg部分则是通过最小化合成图像和改善图像之间的差异以确保对应的注释/标签信息不变。例如,对于凝视估计,学习的变换不应改变眼睛凝视方向。这一部分损失文章中使用的是自我正则化损失,其中的正则化项是L1正则化。在整个训练中我们的目的就是通过最小化损失函数来迭代更新R的参数θ
公式3 改善网络R的目标损失函数
算法
总体来说模型训练的目的便是最小化这两个目标损失函数来获取到改善网络和鉴别器网络的参数。论文中也总结了模型参数的训练过程,如算法1所示,在迭代中通过随机梯度下降算法SGD交替更新改善网络和鉴别器网络的参数,使得损失函数得到收敛,从而得到较优的参数。
以上提到的整个SimGans模型从本质上来说其实只是将生成对抗式网络的思想用于改善合成图像的质量,并对目标损失函数做了一定的修改。除此之外,本文还对整个训练框架做了一些关键的改动
算法1
Local Adversarial Loss
一方面,加入局部对抗性损失。当我们训练一个单一强壮的鉴别器网络时,改善器网络R会倾向于过度强调某些图像特征,可以理解为通过局部作假的方式来愚弄当前的鉴别器网络,忽视了整张图片的特征,导致漂移和伪影。而实际上呢,真实数据的任何局部数据块在鉴别器看来都应该是真实的。因此,为了避免训练过程中鉴别器过强,导致生成假的图片,论文中提出让鉴别器网络对图片的局部数据块进行损失计算,如图4所示,保证改善后图片的每一个区域都应该趋向于真实的图片,而不是整张图片进行直接计算,也就是让每张图片有多个真实感损失值。
具体来说,并不是定义一个全局的鉴别器,而是分别去鉴别图片的每一个区域。鉴别网络输出WxH个区域的值,然后将每个区域进行交叉熵计算,得到局部损失,并进行累加,将这些局部对抗损失平均化以获得更平衡的全局对抗损失,从而提高整个网络的性能。
图4 Local Adversarial Loss
如图5所示,本文也针对该改进方法进行了实验比对。证明了使用局部对抗性损失的重要性。左边是在手势识别中用全局对抗性损失进行训练所生成的一个示例图像,可以发现生成的图像手边缘含有明显的不切实际的深度边界伪影。右边是用局部对抗性损失进行训练所生成的同一图像,看起来更现实,少了许多的噪声。
图5 局部对抗性损失的实验对比
Updating Discriminator using a History of Refined Images
而在对抗训练中存在着另外一个问题就是鉴别器网络只会注意到当前最新的改善图片,这会导致两个问题,一是整个训练的分歧;二是改善网络会生成鉴别器网络已经忘记了的重复的图片,这会导致整个鉴别器网络的稳定性变差。理论上对于鉴别器而言,任何时候改善网络所生成的图片都应该判别为假,基于此,为了提高对抗性训练的稳定性,本文提出了用之前训练过的旧的改善图片来再次更新鉴别器网络,而不是只用当前批量中的改善图像来更新鉴别器。
如图6所示,具体改善方式体现在算法中,是在训练鉴别器网络的过程中设一个缓冲区buffer,用来保存之前训练过的旧的改善图像,缓冲区大小为B,b为鉴别器每次训练图片批量的大小。在鉴别器训练的每一次迭代中,使用b/2的当前改善网络所生成的图片和b/2的缓存区中的图片来更新鉴别器网络参数。训练过程始终保持缓冲区B的数量固定,每次训练结束后用随机的b/2的新生成的改善图片来替换缓存区的图片。
图6 Updating Discriminator using a History of Refined Images
如图7所示,本文也针对该改进方法进行了实验比对。左边是合成图像,中间是使用了阶段性的旧的改善图像进行训练所生成的合成图,右边的则是仅使用最新的改善器所产生的数据进行训练所生成的合成图。我们可以观察到右边的图像有明显的伪影,特别是在眼角周围。
图7 Using a history of refined images for updating the
discriminator
2
实验部分本文主要举了两个例子,将文中所提到的模型用于凝视方向估计和手势识别,他们的训练数据信息如表1所示。
凝视方向估计和手势识别的训练均存在着以下问题,
一、训练数据难以标记,数据昂贵而稀少。
二、两种训练数据均可以通过建模合成,但这些合成数据画面不够真实。
两种实验的实验分析均包括了定性和定量两种结果
表1
Experiment-Gaze Estimation
如图8所示,在凝视估计实验中,左边是真实图像,右边是合成图像及相应的改善输出图像。我们可以观察到合成图像经过SimGANs改善后相较原图更具真实感,也保持了原图的视觉凝视方向/标签信息。
图8 凝视估计实验定性分析结果
实验中作者也训练了一个CNN网络来预测凝视方向,量化了SimGANs以假乱真的能力。如图9所示,将SimGANs修正后的图像作为训练数据可以显著提高模型的预测能力。
图9 凝视估计实验定量分析结果
同时,论文作者也提供了最直观的模型预测错误率作为参考,如表2所示,将该方法的实验结果与在同一数据集上最先进的预测方法进行比对,我们也可以发现在SimGANs改善后图像上进行训练的模型的预测性能相较于其他是有显著提升的。
表2
3
Experiment-Gaze Estimation
另一个实验则是将论文中提到的改善方法应用于手势识别实验中,同样可以分为定性和定量两种分析结果。
图10展示了对应实验的定性结果,左边是真实图像,右边是合成图像及相应的改善输出图像。真实图像中噪声的主要来源是非光滑阴影边界。改善网络可以对真实图像中存在的噪声进行建模,不需要任何真实图像的标签,可以明显看出经过改善后的合成图像更具真实感。
图10 手势识别实验定性分析结果
同时作者实验中也建立了手势估计的CNN网络来进行性能量化评估,在图11中可以看到,在改善图像上进行训练的模型明显优于在真实图像下进行监督训练的模型,总体提高了8.8%。
以上便是具体的实验相关分析,实验中用的的卷积神经网络结构在论文中也有详细的描述,有兴趣的同学可以再仔细研读。
图11 手势识别实验定性分析结果
总结
本文通过对抗性训练减小图像分布之间的差距,使得合成图像更加逼真。训练过程的数据因为不需要进行事先标记,相对容易获取,获得的改善图像也可以作为训练数据用于其他模型的训练,这相当于数据扩增的一种方法。
本文主要具有以下贡献:
1、提出了一种使用未标记的真实数据来改善模拟器所生成的合成图像的S+U学习
2、训练一个改善网络,通过对抗性损失和自我正则化损失的组合,增加合成图像真实感并使修正后的图像能保留标注信息。
3、对GAN训练框架进行了几处关键的修改,稳定训练过程和控制改善的结果