此示例说明如何训练生成对抗网络(Generative Adversarial Network)来生成图像;
一:原理
GAN:一种深度学习网络,能够生成与真实输入数据具有相似特征的数据;
GAN由两个一起训练的网络组成–生成器(Generator)和判别器(Discriminator):
- 生成器(Generator):给定随机值(潜在输入)向量(通常是满足某一特征分布的任意向量,即GAN结构图中的Noise), Generator可以生成(输出)与训练样本具有相同数据特征(结构/分布)的数据(Generated Images);
- 判别器(Discriminator):给定包含训练数据(Real Images)和Generator生成的生成数据(Generated Images),.Discriminator尽量将训练数据和生成数据分别判定为"真实值:“和"生成值”,即输出Predicted Labels;
训练GAN即同时训练Generator和Dsicrimiator,分别最大化两个网络的性能:
- 训练Generator生成"欺骗"Discriminator的数据(Generated Images),即判别器对生生成数据判别为"真实(1)“;
2)训练Discriminator"区分"真实数据"和"生成数据”,即判定真实数据为"1",生成数据为"0";
优化生成器的性能,即给定生成的数据,最大化判别器的损失(最小化生成器的损失),生成器的优化目标是生成判别器判别为真(1)的数据;
优化判别器的性能,即给定真实数据和生成数据,最小化判别器的损失,判别器的优化目标是区分生成数据和真实数据,判别真实数据为1,判定生成数据为0;
理想情况下,训练好的生成器能够生成与真实样本同分布的数据,训练好的判别器能够学习到训练数据特有的强特征表示;
训练GAN的过程如下:
训练数据的获取, 生成器和判别器网络的搭建, 模型梯度,损失函数,以及生成器和判别器分数的定义, GAN模型的训练过程, 以及如何基于训练好的生成器生成新的图像, 都基于以下的MATLAB代码:
https://ww2.mathworks.cn/help/deeplearning/ug/train-generative-adversarial-network.html
(MATLAB-R2020b以后的版本可以训练GAN)