GAN-MNIST实战


前言

使用生成对抗网络(GAN)生成手写数字的任务


一、GAN原理

基本原理是通过训练两个神经网络——生成器和判别器,并通过对抗学习的方式相互竞争和提高性能,从而生成看起来像真实样本的数据

二、使用步骤

1.引入库

代码如下(示例):

import os.path
import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import MatplotlibDeprecationWarning
import warnings

2. matplotlib异常处理

用于取消警告显示和处理莫名其妙的报错:

warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning)

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

3. 数据加载

#数据预处理操作
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),#将图像数据转换为PyTorch张量(Tensor)的格式
    torchvision.transforms.Normalize(0.5,0.5)#对图像进行标准化处理
])
#加载数据
real_data = torchvision.datasets.MNIST(
    root='../FGSM-MNIST-master/data',
    train=True,
    transform=transform,#0-1; channel,high,width
    download=True
)
bs = 128
real_dataloader = DataLoader(
    dataset=real_data,
    batch_size=bs,#批次大小
    shuffle=True#随机打乱
)
imgs,_=next(iter(real_dataloader))#获取图像示例
print(imgs.shape)#输出一个批次图像的数据形状 torch.Size([128, 1, 28, 28])

4. 定义生成器

简化实验,使用三层全连接网络实现

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(in_features=100, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=28*28),
            nn.Tanh() #将输入值映射到-1, 1之间
      )

    def forward(self, X):
        return self.network(X).view(-1,28,28,1)##修改形状,-1自动推理,28*28 1通道

5. 定义辨别器

与传统的ReLU函数相比,nn.LeakyReLU()能够更好地处理梯度消失的问题,因为它在负值区域有一个较小的斜率,防止了神经元"死亡"并促进了梯度向后传播。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.network=nn.Sequential(
            nn.Linear(28*28,512),
            #nn.LeakyReLU f(x) :x>0 输出0,如果x<0,输出 a*x a表示一个很小的斜率,如0.1
            nn.LeakyReLU(),
            nn.Linear(512,256),
            nn.LeakyReLU(),
            nn.Linear(256,1),
            nn.Sigmoid()#转为概率值
        )
    def forward(self,x):
        x = x.view(-1,28*28)
        return self.network(x)

6. 训练迁移/设置优化器

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)

d_optim = torch.optim.Adam(dis.parameters(),lr=2e-4)#辨别优化器
g_optim = torch.optim.Adam(gen.parameters(),lr=2e-4)#生成优化器
loss_fn = torch.nn.BCELoss()#二元交叉熵损失函数

7. 绘制训练图像

def gen_img_plot(model, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(prediction.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow((prediction[i]+1)/2)  # 0~1之间
        plt.axis('off')
    plt.show()

8. 循环训练

D_loss=[]
G_loss=[]
count = len(real_dataloader)

#训练循环
for epoch in range(20):
    print(f"epoch {epoch + 1}\n-----------------")


    d_epoch_loss=0
    g_epoch_loss = 0

    for i, (X_real, _) in enumerate(real_dataloader):
        ###鉴别器
        X_real = X_real.to(device)

        size = X_real.size(0)

        random_noise=torch.randn(size, 100, device=device)#构造图片

        d_optim.zero_grad()#避免梯度累加

        real_output = dis(X_real)   #判别器输入真实图片,real_output对真实图片的判别结果

        d_real_loss = loss_fn(real_output,
                              torch.ones_like(real_output)) #1.判别器在真实图像的损失

        d_real_loss.backward()#计算真实图像损失


        gen_img = gen(random_noise) #生成假图片

        fake_output = dis(gen_img.detach()) #判别器输入生成的图片,fake_output对生成图片的预测 优化判别器#截断梯度,不利用生成器训练辨别器

        d_fake_loss = loss_fn(fake_output,
                              torch.zeros_like(fake_output)
                            )#2.得到判别器在生成图像上的损失

        d_fake_loss.backward()#计算生成图像的损失

        d_loss= d_real_loss+d_fake_loss

        d_optim.step() #参数更新操作

        ###生成器

        g_optim.zero_grad()#避免梯度累加

        fake_output=dis(gen_img)#需要记录梯度

        g_loss=loss_fn(fake_output,
                       torch.ones_like(fake_output)
                   ) #生成器损失

        g_loss.backward()
        g_optim.step() #参数更新操作
        with torch.no_grad():#不需要进行反向传播或计算梯度,只对损失进行累加
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss

        if i % 100 == 0:
            print(
                f"loss_G: {d_loss.item()}, loss_D: {g_loss.item()}, D(x): {d_real_loss.item()}, D(G(z)): {d_fake_loss.item()}")
    with torch.no_grad():
            d_epoch_loss /=count
            g_epoch_loss /=count
            D_loss.append(d_epoch_loss)
            G_loss.append(g_epoch_loss)

            gen_img_plot(gen, test_input)

torch.save(gen.state_dict(), 'model_my.pth')

三、训练过程示例

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述


  • 7
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
### 回答1: 使用TensorFlow来训练并测试手写数字识别的MNIST数据集十分简单。首先,我们需要导入TensorFlow和MNIST数据集: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data 接下来,我们可以使用input_data.read_data_sets()函数加载MNIST数据集,其中参数为下载数据集的路径。我们可以将数据集分为训练集、验证集和测试集。这里我们将验证集作为模型的参数调整过程,测试集用于最终模型评估。 mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 接下来,我们可以使用TensorFlow创建一个简单的深度学习模型。首先,我们创建一个输入占位符,用于输入样本和标签。由于MNIST数据集是28x28的图像,我们将其展平为一个784维的向量。 x = tf.placeholder(tf.float32, [None, 784]) y = tf.placeholder(tf.float32, [None, 10]) 接下来,我们可以定义一个简单的全连接神经网络,包含一个隐藏层和一个输出层。我们使用ReLU激活函数,并使用交叉熵作为损失函数。 hidden_layer = tf.layers.dense(x, 128, activation=tf.nn.relu) output_layer = tf.layers.dense(hidden_layer, 10, activation=None, name="output") cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output_layer, labels=y)) 然后,我们可以使用梯度下降优化器来最小化损失函数,并定义正确预测的准确率。这样就完成了模型的构建。 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) correct_prediction = tf.equal(tf.argmax(output_layer, 1), tf.argmax(y, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 接下来,我们可以在一个会话中运行模型。在每次迭代中,我们从训练集中随机选择一批样本进行训练。在验证集上进行模型的参数调整过程,最后在测试集上评估模型的准确率。 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for i in range(1000): batch_x, batch_y = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_x, y: batch_y}) val_accuracy = sess.run(accuracy, feed_dict={x: mnist.validation.images, y: mnist.validation.labels}) print("Validation Accuracy:", val_accuracy) test_accuracy = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}) print("Test Accuracy:", test_accuracy) 通过这个简单的代码,我们可以使用TensorFlow训练并测试MNIST数据集,并得到测试集上的准确率。 ### 回答2: gan tensorflow mnist是指使用TensorFlow框架训练生成对抗网络GAN)来生成手写数字图像的任务。 首先,手写数字数据集是一个非常常见且经典的机器学习数据集。MNIST数据集包含了由0到9之间的手写数字的图像样本。在gan tensorflow mnist任务中,我们的目标是使用GAN来生成与这些手写数字样本类似的新图像。 GAN是一种由生成器和判别器组成的模型。生成器任务是生成看起来真实的图像,而判别器任务是判断给定图像是真实的(来自训练数据集)还是生成的(来自生成器)。这两个模型通过对抗训练来相互竞争和提高性能。 在gan tensorflow mnist任务中,我们首先需要准备和加载MNIST数据集。利用TensorFlow的函数和工具,我们可以轻松地加载和处理这些图像。 接下来,我们定义生成器和判别器模型。生成器模型通常由一系列的卷积、反卷积和激活函数层组成,以逐渐生成高质量的图像。判别器模型则类似于一个二分类器,它接收图像作为输入并输出真实或生成的预测结果。 我们使用TensorFlow的优化器和损失函数定义GAN模型的训练过程。生成器的目标是误导判别器,使其将生成的图像误认为是真实图像,从而最大限度地降低判别器的损失函数。判别器的目标是准确地区分真实和生成的图像,从而最大限度地降低自身的损失函数。 最后,我们使用训练数据集来训练GAN模型。通过多次迭代,生成器和判别器的性能会随着时间的推移而得到改善。一旦训练完成,我们可以使用生成器模型来生成新的手写数字图像。 总结来说,gan tensorflow mnist是指使用TensorFlow框架训练生成对抗网络来生成手写数字图像的任务。通过定义生成器和判别器模型,使用优化器和损失函数进行训练,我们可以生成类似于MNIST数据集手写数字的新图像。 ### 回答3: 用TensorFlow训练MNIST数据集可以实现手写数字的分类任务。首先我们需要导入相关库和模块,如tensorflow、keras以及MNIST数据集。接着,我们定义模型的网络结构,可以选择卷积神经网络(CNN)或者全连接神经网络(DNN)。对于MNIST数据集,我们可以选择使用CNN,因为它能更好地处理图像数据。 通过调用Keras中的Sequential模型来定义网络结构,可以添加多个层(如卷积层、池化层、全连接层等),用来提取特征和做出分类。其中,输入层的大小与MNIST图片的大小相对应,输出层的大小等于类别的数量(即0~9的数字)。同时,我们可以选择优化器(如Adam)、损失函数(如交叉熵)和评估指标(如准确率)。 接下来,我们用模型编译来配置模型的学习过程。在编译时,我们可以设置优化器、损失函数和评估指标。然后,我们用训练数据对模型进行拟合,通过迭代优化来调整模型的权重和偏置。迭代次数可以根据需要进行调整,以达到训练效果的需求。 训练结束后,我们可以使用测试数据对模型进行评估,获得模型在测试集上的准确率。最后,我们可以使用模型对新的未知数据进行预测,得到相应的分类结果。 综上所述,使用TensorFlow训练MNIST数据集可以实现手写数字的分类任务,通过定义模型结构、编译模型、拟合模型、评估模型和预测来完成整个过程。这个过程需要一定的编程知识和理解深度学习的原理,但TensorFlow提供了方便的api和文档,使我们能够相对容易地实现这个任务。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值