GAN简单介绍—使用PyTorch框架搭建GAN对MNIST数据集进行训练

GAN 的简单介绍

GAN,全称Generative Adversarial Networks,即生成对抗网络,是深度学习中一种强大的生成模型。GAN 是由 Ian Goodfellow 等人在 2014 年提出的,通过让两个神经网络相互对抗来生成新的数据。

GAN(Generative Adversarial Networks,生成对抗网络)是一种深度学习框架,用于生成新的、与训练数据相似的数据。GAN由两个神经网络组成:生成器(Generator)和判别器(Discriminator)。这两个网络在训练过程中相互竞争和协作,使得生成器能够生成越来越逼真的数据。

本文将介绍如何使用 PyTorch 实现 GAN。

GAN 介绍

GAN 的工作原理是,生成器(Generator)和判别器(Discriminator)两个神经网络相互对抗地学习。生成器用于生成图像或数据,而判别器则用于判断输入的数据是否真实。两个神经网络不断地交替训练,直到生成器可以生成接近于真实样本的样本。

GAN 可以用于生成各类数据,如图像、音频、文本等。在图像生成方面,GAN 用于生成无限多张与训练数据相似却并不存在于训练数据中的图像。GAN 可以应用于许多领域,如计算机图形学、自然语言处理等等。

实现 GAN

加载数据集

使用 PyTorch 中的 MNIST 数据集进行训练。MNIST 数据集包含手写数字图像,大小为 28x28 像素。

import torch
from torchvision import datasets, transforms

# 使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载 MNIST 数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
mnist_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(mnist_data, batch_size=128, shuffle=True)

定义生成器和判别器

生成器和判别器都是神经网络模型。生成器将从随机噪声中生成图像,而判别器将对输入图像进行分类,判断它是真实图像还是生成器生成的伪造图像。

import torch.nn as nn

# 定义生成器
class Generator(nn.Module):
    def __init__(self, z_dim=100, hidden_dim=128):
        super(Generator, self).__init__() # 继承 nn.Module 类并初始化
        self.z_dim = z_dim # 输入向量的维度
        self.hidden_dim = hidden_dim # 隐藏层维度
        self.fc1 = nn.Linear(z_dim, hidden_dim) # 全连接层,输入为 z_dim 维,输出为 hidden_dim 维
        self.fc2 = nn.Linear(hidden_dim, 28 * 28) # 全连接层,将隐藏层映射到 28*28 的图像

    def forward(self, z):
        x = self.fc1(z) # 输入 z,通过全连接层 fc1 得到隐藏层向量 x
        x = nn.functional.leaky_relu(x, 0.2) # 在隐藏层中应用 LeakyReLU 激活函数
        x = self.fc2(x) # 将隐藏层映射到生成的图像
        x = nn.functional.tanh(x) # 将输出值映射到 [-1, 1] 的范围
        return x.view(-1, 1, 28, 28) # 将输出展平成图片张量形式

# 定义判别器
class Discriminator(nn.Module):
    def __init__(self, hidden_dim=128):
        super(Discriminator, self).__init__() # 继承 nn.Module 类并初始化
        self.hidden_dim = hidden_dim # 隐藏层维度
        self.fc1 = nn.Linear(28 * 28, hidden_dim) # 输入为 28*28 的图像,输出为隐藏层向量
        self.fc2 = nn.Linear(hidden_dim, 1) # 只有一个输出,表示输入是否是真实的图像。

    def forward(self, x):
        x = x.view(-1, 28*28) # 将输入展平为向量
        x = self.fc1(x) # 将展平后的向量通过全连接层 fc1 映射到隐藏层
        x = nn.functional.leaky_relu(x, 0.2) # 在隐藏层中应用 LeakyReLU 激活函数
        x = self.fc2(x) # 将隐藏层输出映射到单个输出值
        x = nn.functional.sigmoid(x) # 将输出值映射到 [0, 1] 的范围,表示输入是否是真实图像的概率
        return x

这段代码实现了一个基本的生成对抗网络(GAN)。GAN 由两个神经网络组成,一个生成器(Generator)和一个判别器(Discriminator),它们在博弈中对抗地进行训练。生成器从随机噪声中生成假图像,而判别器则试图区分真实图像和生成图像。训练过程目的是使得生成器生成的图像更逼真,从而欺骗判别器,使其将生成的图像和真实图像分类错误。以下是代码实现的具体步骤:

  1. 导入 torch.nn 库,以使用 PyTorch 提供的深度学习模型库;
  2. 定义生成器:继承 nn.Module 类并初始化输入随机噪声向量的维度(z_dim)和隐藏层维度(hidden_dim),以及两个全连接层 fc1 和 fc2。在 forward 方法中,通过全连接层将输入向量 z 映射成隐藏层向量 x,然后在隐藏层中应用 leaky ReLU 激活函数,将隐藏层映射到输出图像,最后将输出值映射至 [-1, 1] 的范围并将其展平成图片张量形式;
  3. 定义判别器:同样继承 nn.Module 类并初始化隐藏层维度(hidden_dim)和两个全连接层 fc1 和 fc2。在 forward 方法中,将输入的二维图片张量展平为向量 x,通过全连接层 fc1 映射到隐藏层,应用 leaky ReLU 激活函数,最后将隐藏层输出通过全连接层 fc2 映射为单个输出值,并将其映射至 [0, 1] 的范围,表示输入是否是真实图像的概率;
  4. GAN 的训练过程:定义生成器和判别器的优化器(optimizer),设置目标损失函数(loss function),以及迭代训练的循环。每轮迭代都以真实图像或噪声输入生成器作为输入,计算生成器对应的生成图像,将真实的和生成的数据放入判别器中分别计算输出,分别计算判别器对真实数据的损失和对生成数据的损失,然后利用这两个损失更新生成器和判别器的参数。最后的目标是使得生成器生成尽可能逼真的数据,欺骗判别器,从而将其指导生成尽可能逼真的数据。

定义超参数、优化器、损失函数

定义超参数、优化器、损失函数。

# 定义超参数
lr = 0.0002
z_dim = 100
num_epochs = 50

generator = Generator(z_dim).to(device)
discriminator = Discriminator().to(device)

# 定义优化器和损失函数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr)

criterion = nn.BCELoss()

训练网络

# 训练网络,使用给定的 epoch 数量
for epoch in range(num_epochs):
    # 遍历数据集的 mini-batches
    for i, (real_images, _) in enumerate(data_loader):
        # 将真实图像传递给设备
        real_images = real_images.to(device)

        # 定义真实标签为 1,假的标签为 0
        real_labels = torch.ones(real_images.size(0), 1).to(device)
        fake_labels = torch.zeros(real_images.size(0), 1).to(device)

        # 训练生成器
        # 生成随机的噪音 z
        z = torch.randn(real_images.size(0), z_dim).to(device)

        # 生成 fake_images
        fake_images = generator(z)

        # 将 fake_images 传递给判别器,得到输出 fake_output
        fake_output = discriminator(fake_images)

        # 计算生成器的损失值 loss_G
        # 需要将 fake_output 与真实标签 real_labels 进行比较
        loss_G = criterion(fake_output, real_labels)

        # 重置 generator 的优化器,清除梯度
        optimizer_G.zero_grad()

        # 计算 generator 的梯度
        loss_G.backward()

        # 更新 generator 参数,使用优化器 optimizer_G
        optimizer_G.step()

        # 训练判别器
        # 将真实图像 real_images 通过判别器得到输出 real_output
        real_output = discriminator(real_images)

        # 计算判别器对真实图像的损失值 loss_D_real
        # 需要将 real_output 与真实标签 real_labels 进行比较
        loss_D_real = criterion(real_output, real_labels)

        # 计算判别器对生成器生成的 fake_images 的损失值 loss_D_fake
        # 需要将 fake_output 与假的标签 fake_labels 进行比较
        fake_output = discriminator(fake_images.detach())
        loss_D_fake = criterion(fake_output, fake_labels)

        # 计算判别器总损失值 loss_D,即 loss_D_real 和 loss_D_fake 的和
        loss_D = loss_D_real + loss_D_fake

        # 重置 discriminator 的优化器,清除梯度
        optimizer_D.zero_grad()

        # 计算 discriminator 的梯度
        loss_D.backward()

        # 更新 discriminator 参数,使用优化器 optimizer_D
        optimizer_D.step()

        # 输出损失值
        # 每100次迭代输出一次
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, len(data_loader), loss_D.item(), loss_G.item()))

这段代码表示了一个完整的 GAN 利用交替梯度下降算法进行训练的过程。在每个 epoch 内,我们遍历数据集中的所有 mini-batches,并对每个 mini-batch 迭代进行以下步骤:

  1. 将真实图片 real_images 传递给设备。
  2. 定义标签,其中真实图像对应的标签为 1,生成器生成的假图像对应的标签为 0。
  3. 训练生成器。从一个经过随机初始化的噪声向量 z 开始,使用生成器学习生成伪造图片。将生成的图片 fake_images 传递给判别器,并使用判别器输出来计算生成器的损失值。这个损失值反映了生成器生成的图片与真实标签之间的差异。
  4. 重置生成器的优化器,并清空梯度。
  5. 对生成器的损失值进行反向传播计算梯度。
  6. 使用优化器来更新生成器的参数,使其在下一个迭代中能够更好地生成伪造图片。
  7. 训练判别器。使用真实图片来计算判别器对真实图片的损失值,用假图片来计算判别器对伪造图片的损失值,并将这些值相加以得到判别器的总损失。使用反向传播来计算梯度并更新判别器参数。
  8. 每 100 次迭代输出一次损失值,以跟踪训练的进展。
  9. 循环在所有训练数据上结束时,一次完整的迭代称为一个 epoch。重复执行此步骤 num_epochs 次来完成 GAN 的训练。

使用生成器生成新图像

在训练结束后,可以使用生成器来生成新图像。

# 生成一些测试数据
# 随机生成一些长度为 z_dim 的、位于设备上的向量 z
z = torch.randn(16, z_dim).to(device)

# 使用生成器从 z 中生成一些假的图片
fake_images = generator(z).detach().cpu()

# 显示生成的图像
# 创建一个图形对象,大小为 4x4 英寸
fig = plt.figure(figsize=(4, 4))

# 在图形对象中创建4x4的网格,以显示输出的16张假图像
for i in range(16):
    plt.subplot(4, 4, i+1)
    plt.imshow(fake_images[i][0], cmap='gray')
    plt.axis('off')

# 显示绘制的图形
plt.show()

这段代码的目的是在训练 GAN 结束后,使用生成器生成一些随机的假图像进行展示。下面是具体步骤:

  1. 随机生成一些长度为 z_dim 的、位于设备上的向量 z
  2. 使用生成器从 z 中生成一些假的图片。
  3. 创建一个 4x4 英寸大小的图形对象。
  4. 在图形对象中创建一个 4x4 的网格,用于显示生成的 16 张图片。
  5. 使用 plt.imshow() 函数将图像附加到当前子图中。该函数会在 matplotlib 窗口中显示生成的图像。
  6. 从绘图中删除坐标轴。
  7. 显示生成的图形。

结论

GAN 是一种强大的生成模型,可以用于生成各种不同类型的数据。本教程展示了如何使用 PyTorch 实现 GAN,并使用 MNIST 数据集进行训练。

  • 5
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
使用PyTorch GAN训练自己的数据集,你需要进行以下步骤: 1. 准备数据集:首先,你需要准备你自己的数据集。确保你的数据集符合PyTorch的要求,每个样本都是一个Tensor类型的图像,并且尺寸一致。 2. 创建数据加载器:使用PyTorch的DataLoader类创建一个数据加载器,可以帮助你在训练过程中有效地加载和处理数据。你可以指定批量大小、数据的随机顺序等参数。 3. 定义生成器和判别器模型:根据你的数据集,定义生成器和判别器的模型。生成器模型将一个随机噪声向量作为输入,并生成一个与数据集相似的图像。判别器模型将图像作为输入,并输出一个值,表示该图像是真实图像还是生成图像。 4. 定义损失函数和优化器:为生成器和判别器定义适当的损失函数,通常是二分类交叉熵损失。然后,为每个模型创建一个优化器,例如Adam优化器。 5. 训练GAN模型:使用循环迭代的方式,在每个epoch中遍历数据集的所有mini-batches,并根据GAN训练的过程进行以下步骤:先训练生成器,传递真实图像和生成的假图像给判别器,并计算生成器的损失。然后,训练判别器,计算判别器对真实图像和生成的假图像的损失,并更新判别器的参数。重复这个过程,直到完成所有的epochs。 6. 生成新图像:训练完成后,你可以使用生成器模型生成新的图像。只需要提供一个随机噪声向量作为输入,通过生成器模型生成对应的图像。 请注意,这只是一个大致的概述,具体的实现细节会根据你的数据集GAN模型的架构而有所不同。你需要根据你的需求进行相应的调整和优化。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* *3* [GAN简单介绍使用PyTorch框架搭建GANMNIST数据集进行训练](https://blog.csdn.net/qq_36693723/article/details/130332573)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 100%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

百年孤独百年

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值