GAN同时要训练一个生成网络(Generator)和一个判别网络(Discriminator),前者输入一个noise变量
z
,输出一个伪图片数据
根据上述训练过程的描述,我们可以定义一个损失函数:
Loss=1m∑mi=1[logD(xi)+log(1−D(G(zi)))]
其中
xi
,
zi
分别是真实的图片数据以及noise变量。
而优化目标则是:
minGmaxDLoss
不过需要注意的一点是,实际训练过程中并不是直接在上述优化目标上对 θd , θg 计算梯度,而是分成几个步骤:
训练判别器即更新
θd
:循环
k
次,每次准备一组real image数据
∇θd1m∑mi=1[logD(xi)+log(1−D(G(zi)))]
然后梯度上升法更新
θd
;
训练生成器即更新
θg
:准备一组fake image数据
z=z1,z2,⋯,zm
,计算
∇θg1m∑mi=1log(1−D(G(zi)))
然后梯度下降法更新
θg
。
可以看出,第一步内部有一个
k
<script type="math/tex" id="MathJax-Element-2008">k</script>层的循环,某种程度上可以认为是因为我们的训练首先要保证判别器足够好然后才能开始训练生成器,否则对应的生成器也没有什么作用,然后第二步求提督时只计算fake image那部分数据,这是因为real image不由生成器产生,因此对应的梯度为0。