写在前面:最近看了《GAN实战》,由于本人忘性大,所以仅是笔记而已,方便回忆,如果能帮助大家就更好了。
生成器 | 鉴别器 | |
---|---|---|
输入 | 一个随机数向量(噪声) | 来自训练集的真实样本,来自生成器的伪样本 |
输出 | 尽可能令人信服的伪样本 | 预测输入样本是真实的概率 |
目标 | 生成与训练集中数据无差异的伪数据 | 区分来自生成器的伪样本和来自训练集的真实样本 |
GAN训练过程:
(1)训练鉴别器
a.从训练集中随机抽取真实样本x。
b.获取一个新的随机噪声向量z,用生成器网络合成一个伪样本x*。
c.用鉴别器网络对x和x*进行分类
d.计算分类误差并反向传播总误差以更新鉴别器的可训练参数,寻求最小化分类误差。
(2)训练生成器
a.获取一个新的随机噪声向量z,用生成器网络合成一个伪样本x*。
b.用鉴别器网络对x和x*进行分类
c.计算分类误差并反向传播总误差以更新鉴别器的可训练参数,寻求最大化鉴别器误差。
当整个网络达到纳什均衡时停止:
(1)生成器生成的伪样本与训练集中真实数据别无二致。
(2)鉴别器所能做的只是一个随机猜测样本是真是假(50%概率)
当达到纳什均衡后,认为GAN达到收敛。