GAN生成对抗网络基本概念及基于mnist数据集的代码实现

本文深入介绍了生成对抗网络(GAN)的基本原理,包括生成模型和判别模型的构建,并通过MNIST数据集展示了GAN的训练过程。在训练过程中,通过调整模型参数,生成模型试图生成逼真的数字图像,而判别模型则努力区分真实和伪造数据。经过训练,生成的图像已显示出可辨识的数字,如5、7、9。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文主要总结了GAN(Generative Adversarial Networks) 生成对抗网络的基本原理并通过mnist数据集展示GAN网络的应用。

GAN网络是由两个目标相对立的网络构成的,在所有GAN框架中都至少包含了两个部分,生成模型部分和判别模型部分。生成模型的目标是制造出一些与真实数据十分相似的伪造数据而判别模型的目标则恰恰相反,是找到如何分辨这些真实的数据以及伪造数据。

下图可以用来比较简明地理解GAN的工作原理 :
在这里插入图片描述
生成模型的输入是随机的噪声编码 z z z,通过这个噪声生成的数据 G ( z ) G(z) G(z) 就是我们伪造出的数据了。判别模型的输入是一组混合了真实数据 x x x以及伪造的数据 G ( z ) G(z) G(z)的混合数据并输出 D ( G ( z ) ) D(G(z)) D(G(z)) 以及 D ( x ) D(x) D(x),代表了对真实数据和伪造数据的判定。如果我们把伪造数据的标签定为0,真实数据的标签定为1,那么判别模型的训练目标就是使 D ( G ( z ) ) D(G(z)) D(G(z))无限接近0,使 D ( x ) D(x) D(x)无限接近1,以此来达到分辨真实数据和伪造数据的目的。相反的,生成模型的训练目标则是要使得 D ( G ( z ) ) D(G(z)) D(G(z))接近1,即达到欺骗判别模型,以假乱真的目的。我们不难发现,实际生成模型的训练离不开判别模型的判定,而判别模型的训练也需要生成模型生成的伪造数据,二者相辅相成。这一点在下面基于mnist数据集的训练代码中也会有所体现。

首先是import所需库并导入mnist数据,我们通过全部除以255的方法正则化用于训练和测试的图像。

import tensorflow as tf
from tensorflow import keras
from matplotlib import pyplot as plt
from keras.layers import Dense, Conv2DTranspose, BatchNormalization, Reshape, LeakyReLU, Conv2D
import numpy as np

# load data from database mnist
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(path="mnist.npz")

# normaliser à [0,1]
x_train = x_train/255.0
x_test = x_test/255.0

训练集x_train中包含了60000张 28*28的单通道灰度图片,每张图片对应标签为0-9的十个数字,下面展示其中一张代表了数字5的图片。
在这里插入图片描述

1. 生成模型部分

首先是我们的生成模型,如上所说,生成模型的输入为随机噪声,输出为伪造的数据。在这个例子中,我们最终要输出一张与真实图片大小一致的灰度图。生成模型由两类主要的层构成,其中之一就是全连接层dense,这个层实际上类似于CNN卷积神经网络中代表特征 feature 的一层,我们可以理解为它由数个低解析度的图像组成。之后需要的就是将解析度提升至的操作,这里用了Conv2DTranspose层,可以理解为是一个反向的pooling池化层(用于还原参数和数据)和一个2D卷积层的结合。Conv2DTranspose层中的参数stride设置为(2, 2) 即保证了每经过一次该层,输出的宽度和高度都扩大一倍。如下例所示,由7 * 7 经过两次Conv2DTranspose层使得最终输出的灰度图宽度和高度为 28 * 28 。生成模型的输出层是一个简单的2D卷积层,使用activation激励函数为sigmoid,这是由于sigmoid函数可以使得输出值属于[0, 1]的区间,也对应了我们在一开始在数据预处理的时候,将数据正则化至[0, 1]的操作。

# creation of a generator
def creation_generateur(dim_latent=10):
  generator = keras.models.Sequential()
  generator.add(Dense(128*7*7, input_dim=dim_latent))
  generator.add(Reshape((7,7,128)))
  # upsampling 
  generator.add(Conv2DTranspose(filters=128,kernel_size=(5,5),strides=(2,2),padding="same"))
  generator.add(LeakyReLU(alpha=0.2))
  # upsampling 
  generator.add(Conv2DTranspose(filters=128,kernel_size=(5,5),strides=(2,2),padding="same"))
  generator.add(LeakyReLU(alpha=0.2))
  generator.add(Conv2D(1, kernel_size=(7, 7), activation='sigmoid', padding="same"))
  return generator

2. 判别模型部分

接着我们创建GAN中的判别模型,相比生成模型而言,判别模型就更加简明,其实质就是一个classifier二元分类器。他由多个卷积层构成,其中添加了drop out用于防止过拟合。输出层是一个仅有一个神经元的全连接层,使用sigmoid作为激励函数。正如我们前文所提到的,判别模型会对输入进行分类,判别输入究竟是真实图像还是由生成模型伪造的图像。

# creation of a discriminator
def creation_discriminateur():
  discriminator 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值