生成对抗网络
监督式学习是指基于大量带有标签的训练集与测试集的机器学习过程,而非监督式学习可以自己从错误中进行学习并降低未来出错的概率。前者的缺点是需要大量标签样本且非常耗时耗力,而后者虽然没有这个问题,但准确率往往比前者低。Goodfellow等人于2014年提出生成对抗网络(GAN),它提供了一种不需要大量标注训练数据就能学习深度表征的方式,可以通过反射传播算法分别更新生成器和判别器以执行竞争性学习而达到训练目的。GAN是一种全新的非监督式的架构,包含一个生成器和一个判别器。这种对抗训练过程与传统神经网络存在一个重要区别,具体的是传统神经网络需要有一个用来评估网络性能如何的成本函数,对抗网络不需要精心设计和构建成本函数而学习自己的成本函数。GAN所学习的表征可被用于图像合成、语义图像编辑、风格迁移、图像超分辨率技术和图像分类等多种应用。
生成器和判别器通常由包含卷积和(或)全连接层的多层网络构成。生成模型能够通过输入的样本产生可能的输出,学习到数据集的概率分布。
假设真实数据集为
x
x
x,为了得到生成模型对于真实数据
x
x
x的概率分布
p
g
p_g
pg,定义将由先验概率分布为
p
z
(
z
)
p_z(z)
pz(z)的噪声所构成的初始数据集映射到
x
x
x所在数据空间的深度网络结构为生成模型
G
(
z
;
θ
g
)
G(z;\theta_g)
G(z;θg)。与此同时定义另一种多层感知机为判别模型
D
(
x
;
θ
)
D(x;\theta)
D(x;θ)。另一方面,判别模型
D
D
D的输入端有两类数据,分别为初始的真实数据
x
x
x和由
G
G
G生成的伪造数据,
D
D
D将输入的承载丰富信息的高维数据转化为类别标签1或是0,即区分真实数据与伪造数据。通过同时训练生成模型
G
G
G与判别模型
D
D
D直至它们达到纳什平衡,使损失函数值
V
(
G
,
D
)
V(G,D)
V(G,D)最小,此时判别模型
D
D
D已无法区分真实数据与伪造数据之间的区别,而生成模型
G
G
G也获得了输入数据的最优概率分布。
min
G
max
D
V
(
D
,
G
)
=
E
x
p
real
(
x
)
[
log
D
(
x
)
]
+
E
Z
p
z
(
z
)
[
log
(
−
D
(
G
(
z
)
)
)
]
\min_G \max_D V(D,G)=\mathbb{E}_{x~p_{\text{real}}(x)}[\log D(x)]+\mathbb{E}_{Z~p_z(z)}\left[\log\left(-D\left(G(z)\right)\right)\right]
GminDmaxV(D,G)=Ex preal(x)[logD(x)]+EZ pz(z)[log(−D(G(z)))]
在训练过程中,
D
D
D会接收真实数据和
G
G
G产生的虚假数据,它的任务是判断图片是属于真数据的还是假数据的。根据最后输出的结果,可以同时对两方的参数进行调优。如果
D
D
D判断正确,那就需要调整
G
G
G的参数从而使得产生的伪造数据更为逼真;如果
D
D
D判断错误,则需调节
D
D
D的参数,避免下次类似判断出错。训练会一直持续到两者进入到纳什平衡状态,这样本的网络结构特性使得对抗学习模型具有更好的泛化能力。纳什平衡是指一种非合作的博弈策略组合,它假定有一个由
n
n
n个人构成的策略组合,如果这
n
n
n个人各自的策略都是为了达到自己期望收益的最大值,同时不受他人策略的影响时,达到的平衡状态。注意,在生成对抗网络模型的训练中一般采用交替优化的方法。生成对抗学习一般涉及多个神经网络模型和多个损失函数,大大增加了模型整体的泛化能力,使其在没有参加过训练的数据上也能表现良好。