一、模型简介
- GAN 原理
- GAN 由生成器(Generative Model)和判别器(Discriminative Model)组成。
- 生成器试图生成逼真的假图像来欺骗判别器,判别器则要准确区分真实图像和生成器生成的假图像。
- 损失函数:
min_G max_D V(D, G) = E_{x∼p_{data}(x)}[logD(x)] + E_{z∼p_z(z)}[log(1 - D(G(z)))]
- 其中,
D(x)
表示判别器判定图像x
为真实图像的概率。 G(z)
表示生成器将隐码z
映射到数据空间生成的图像。
- 其中,
- 推理流程
- 输入隐码
z
到生成器。 - 生成器生成假图像
G(z)
。 - 判别器判断
G(z)
的真假。
- 输入隐码
二、数据集
- 数据集介绍
- MNIST 手写数字数据集,包含 70000 张手写数字图片,已进行尺寸归一化和中心化处理。
- 数据下载
- 函数:
download(url, path, kind, replace)
url
:数据集下载链接。path
:保存路径。kind
:文件类型,如zip
。replace
:是否替换已存在文件。- 示例:
download("https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip", ".", kind="zip", replace=True)
- 函数:
- 数据加载与处理
- 函数:
MnistDataset(dataset_dir)
dataset_dir
:数据集目录。- 示例:
train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
- 数据增强与预处理函数:
map
、batch
等。
- 函数:
三、隐码构造
- 函数:np.random.seed(seed)
,设置随机数种子,确保结果可复现。
- 示例:np.random.seed(2323)
- 构造隐码:test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
四、模型构建
- 生成器
- 类:
Generator
- 构造函数:
__init__(self, latent_size, auto_prefix=True)
latent_size
:隐码长度。auto_prefix
:自动添加参数名前缀,默认为True
。
- 示例:
net_g = Generator(latent_size)
- 构造函数:
- 生成器结构:由多层
Dense
全连接层、BatchNorm1d
批归一化层、ReLU
激活层和Tanh
激活层组成。
- 类:
- 判别器
- 类:
Discriminator
- 构造函数:
__init__(self, auto_prefix=True)
auto_prefix
:自动添加参数名前缀,默认为True
。
- 示例:
net_d = Discriminator()
- 构造函数:
- 判别器结构:由多层
Dense
全连接层和LeakyReLU
激活层组成,最后通过Sigmoid
激活函数输出概率。
- 类:
五、损失函数和优化器
- 损失函数
- 函数:
BCELoss(reduction='mean')
reduction
:损失计算方式,如'mean'
计算均值。- 示例:
adversarial_loss = nn.BCELoss(reduction='mean')
- 函数:
- 优化器
- 函数:
Adam(params, learning_rate, beta1, beta2)
params
:可训练参数。learning_rate
:学习率。beta1
、beta2
:优化器参数。- 示例:
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
- 函数:
六、模型训练
- 训练过程函数
generator_forward(test_noises)
:计算生成器的损失。discriminator_forward(real_data, test_noises)
:计算判别器的损失。train_step(real_data, latent_code)
:进行一步训练,计算判别器和生成器的损失并更新参数。save_imgs(gen_imgs1, idx)
:保存生成的测试图像。
- 训练流程
- 迭代训练周期。
- 在每个周期内,遍历数据集中的批次。
- 计算判别器和生成器的损失,并更新参数。
- 保存生成的图像和模型参数。
七、效果展示
- 绘制损失曲线,观察判别器和生成器损失随训练迭代的变化。
- 生成动态图展示训练过程中生成的图像质量变化。
八、模型推理
- 加载生成器的模型参数。
- 输入随机隐码生成图像并展示。
九、操作流程
- 准备数据:下载 MNIST 数据集,进行加载、增强和预处理。
- 构建模型:定义生成器和判别器的结构。
- 定义损失函数和优化器。
- 训练模型:迭代训练,计算损失,更新参数,保存中间结果。
- 效果评估:查看损失曲线,观察生成图像质量。
- 模型推理:加载模型参数,生成新图像。
十、调用库及功能
mindspore
:构建模型、计算梯度、优化参数等。nn
模块:提供神经网络层和损失函数。ops
模块:操作符和函数。
mindspore.dataset
:加载和处理数据集。numpy
:数据处理和操作。matplotlib
:数据可视化。