1.关于GAN
GAN,全称为生成对抗网络,是一种深度学习模型,由Ian Goodfellow及其同事于2014年提出。 GAN由两个主要的神经网络组成——生成器和判别器。
- 生成器的任务是生成看起来像训练图像的“假”图像;
- 判别器需要判断从生成器输出的图像是真实的训练图像还是虚假的图像。
GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习,从而产生了相当好的输出。
2.数据集
学习使用MNIST手写数字数据集来训练一个生成式对抗网络,使用该网络模拟生成手写数字图片。
数据集下载:使用download
接口下载数据集,并将自动解压到指定目录下。
数据加载:使用MindSpore的MnistDatase
接口,读取和解析MNIST数据集的源文件构建数据集,进行必要的预处理。
数据集可视化:通过create_dict_iterator
函数将数据转换成字典迭代器,并使用matplotlib等工具,可视化部分训练数据以检查数据质量。
隐码构造:为了跟踪生成器的学习进度,在每轮迭代结束后,将一组固定的遵循高斯分布的隐码test_noise
输入到生成器中,通过固定隐码所生成的图像效果来评估生成器的好坏。
3.模型构建
生成器
通过五层 Dense
全连接层来完成的,每层都与 BatchNorm1d
批归一化层和 ReLU
激活层配对,输出数据会经过 Tanh
函数,使其返回 [-1,1] 的数据范围内。
判别器
通过一系列的 Dense
层和 LeakyReLU
层对其进行处理,最后通过 Sigmoid
激活函数,使其返回 [0, 1] 的数据范围内,得到最终概率。
损失函数和优化器
损失函数使用MindSpore中二进制交叉熵损失函数BCELoss
。
为生成器和判别器分别构建Adam优化器,分别用于更新两个模型的参数。
4.模型训练
训练分为两个主要部分:训练判别器和训练生成器。