GAN 的简单介绍
GAN,全称Generative Adversarial Networks,即生成对抗网络,是深度学习中一种强大的生成模型。GAN 是由 Ian Goodfellow 等人在 2014 年提出的,通过让两个神经网络相互对抗来生成新的数据。
GAN(Generative Adversarial Networks,生成对抗网络)是一种深度学习框架,用于生成新的、与训练数据相似的数据。GAN由两个神经网络组成:生成器(Generator)和判别器(Discriminator)。这两个网络在训练过程中相互竞争和协作,使得生成器能够生成越来越逼真的数据。
本文将介绍如何使用 PyTorch 实现 GAN。
GAN 介绍
GAN 的工作原理是,生成器(Generator)和判别器(Discriminator)两个神经网络相互对抗地学习。生成器用于生成图像或数据,而判别器则用于判断输入的数据是否真实。两个神经网络不断地交替训练,直到生成器可以生成接近于真实样本的样本。
GAN 可以用于生成各类数据,如图像、音频、文本等。在图像生成方面,GAN 用于生成无限多张与训练数据相似却并不存在于训练数据中的图像。GAN 可以应用于许多领域,如计算机图形学、自然语言处理等等。
实现 GAN
加载数据集
使用 PyTorch 中的 MNIST 数据集进行训练。MNIST 数据集包含手写数字图像,大小为 28x28 像素。
import torch
from torchvision import datasets, transforms
# 使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载 MNIST 数据集
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
mnist_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(mnist_data, batch_size=128, shuffle=True)
定义生成器和判别器
生成器和判别器都是神经网络模型。生成器将从随机噪声中生成图像,而判别器将对输入图像进行分类,判断它是真实图像还是生成器生成的伪造图像。
import torch.nn as nn
# 定义生成器
class Generator(nn.Module):
def __init__(self, z_dim=100, hidden_dim=128):
super(Generator, self).__init__() # 继承 nn.Module 类并初始化
self.z_dim = z_dim # 输入向量的维度
self.hidden_dim = hidden_dim # 隐藏层维度
self.fc1 = nn.Linear(z_dim, hidden_dim) # 全连接层,输入为 z_dim 维,输出为 hidden_dim 维
self.fc2 = nn.Linear(hidden_dim, 28 * 28) # 全连接层,将隐藏层映射到 28*28 的图像
def forward(self, z):
x = self.fc1(z) # 输入 z,通过全连接层 fc1 得到隐藏层向量 x
x = nn.functional.leaky_relu(x, 0.2) # 在隐藏层中应用 LeakyReLU 激活函数
x = self.fc2(x) # 将隐藏层映射到生成的图像
x = nn.functional.tanh(x) # 将输出值映射到 [-1, 1] 的范围
return x.view(-1, 1, 28, 28) # 将输出展平成图片张量形式
# 定义判别器
class Discriminator(nn.Module):
def __init__(self, hidden_dim=128):
super(Discriminator, self).__init__() # 继承 nn.Module 类并初始化
self.hidden_dim = hidden_dim # 隐藏层维度
self.fc1 = nn.Linear(28 * 28, hidden_dim) # 输入为 28*28 的图像,输出为隐藏层向量
self.fc2 = nn.Linear(hidden_dim, 1) # 只有一个输出,表示输入是否是真实的图像。
def forward(self, x):
x = x.view(-1, 28*28) # 将输入展平为向量
x = self.fc1(x) # 将展平后的向量通过全连接层 fc1 映射到隐藏层
x = nn.functional.leaky_relu(x, 0.2) # 在隐藏层中应用 LeakyReLU 激活函数
x = self.fc2(x) # 将隐藏层输出映射到单个输出值
x = nn.functional.sigmoid(x) # 将输出值映射到 [0, 1] 的范围,表示输入是否是真实图像的概率
return x
这段代码实现了一个基本的生成对抗网络(GAN)。GAN 由两个神经网络组成,一个生成器(Generator)和一个判别器(Discriminator),它们在博弈中对抗地进行训练。生成器从随机噪声中生成假图像,而判别器则试图区分真实图像和生成图像。训练过程目的是使得生成器生成的图像更逼真,从而欺骗判别器,使其将生成的图像和真实图像分类错误。以下是代码实现的具体步骤:
- 导入 torch.nn 库,以使用 PyTorch 提供的深度学习模型库;
- 定义生成器:继承 nn.Module 类并初始化输入随机噪声向量的维度(z_dim)和隐藏层维度(hidden_dim),以及两个全连接层 fc1 和 fc2。在 forward 方法中,通过全连接层将输入向量 z 映射成隐藏层向量 x,然后在隐藏层中应用 leaky ReLU 激活函数,将隐藏层映射到输出图像,最后将输出值映射至 [-1, 1] 的范围并将其展平成图片张量形式;
- 定义判别器:同样继承 nn.Module 类并初始化隐藏层维度(hidden_dim)和两个全连接层 fc1 和 fc2。在 forward 方法中,将输入的二维图片张量展平为向量 x,通过全连接层 fc1 映射到隐藏层,应用 leaky ReLU 激活函数,最后将隐藏层输出通过全连接层 fc2 映射为单个输出值,并将其映射至 [0, 1] 的范围,表示输入是否是真实图像的概率;
- GAN 的训练过程:定义生成器和判别器的优化器(optimizer),设置目标损失函数(loss function),以及迭代训练的循环。每轮迭代都以真实图像或噪声输入生成器作为输入,计算生成器对应的生成图像,将真实的和生成的数据放入判别器中分别计算输出,分别计算判别器对真实数据的损失和对生成数据的损失,然后利用这两个损失更新生成器和判别器的参数。最后的目标是使得生成器生成尽可能逼真的数据,欺骗判别器,从而将其指导生成尽可能逼真的数据。
定义超参数、优化器、损失函数
定义超参数、优化器、损失函数。
# 定义超参数
lr = 0.0002
z_dim = 100
num_epochs = 50
generator = Generator(z_dim).to(device)
discriminator = Discriminator().to(device)
# 定义优化器和损失函数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr)
criterion = nn.BCELoss()
训练网络
# 训练网络,使用给定的 epoch 数量
for epoch in range(num_epochs):
# 遍历数据集的 mini-batches
for i, (real_images, _) in enumerate(data_loader):
# 将真实图像传递给设备
real_images = real_images.to(device)
# 定义真实标签为 1,假的标签为 0
real_labels = torch.ones(real_images.size(0), 1).to(device)
fake_labels = torch.zeros(real_images.size(0), 1).to(device)
# 训练生成器
# 生成随机的噪音 z
z = torch.randn(real_images.size(0), z_dim).to(device)
# 生成 fake_images
fake_images = generator(z)
# 将 fake_images 传递给判别器,得到输出 fake_output
fake_output = discriminator(fake_images)
# 计算生成器的损失值 loss_G
# 需要将 fake_output 与真实标签 real_labels 进行比较
loss_G = criterion(fake_output, real_labels)
# 重置 generator 的优化器,清除梯度
optimizer_G.zero_grad()
# 计算 generator 的梯度
loss_G.backward()
# 更新 generator 参数,使用优化器 optimizer_G
optimizer_G.step()
# 训练判别器
# 将真实图像 real_images 通过判别器得到输出 real_output
real_output = discriminator(real_images)
# 计算判别器对真实图像的损失值 loss_D_real
# 需要将 real_output 与真实标签 real_labels 进行比较
loss_D_real = criterion(real_output, real_labels)
# 计算判别器对生成器生成的 fake_images 的损失值 loss_D_fake
# 需要将 fake_output 与假的标签 fake_labels 进行比较
fake_output = discriminator(fake_images.detach())
loss_D_fake = criterion(fake_output, fake_labels)
# 计算判别器总损失值 loss_D,即 loss_D_real 和 loss_D_fake 的和
loss_D = loss_D_real + loss_D_fake
# 重置 discriminator 的优化器,清除梯度
optimizer_D.zero_grad()
# 计算 discriminator 的梯度
loss_D.backward()
# 更新 discriminator 参数,使用优化器 optimizer_D
optimizer_D.step()
# 输出损失值
# 每100次迭代输出一次
if (i + 1) % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, len(data_loader), loss_D.item(), loss_G.item()))
这段代码表示了一个完整的 GAN 利用交替梯度下降算法进行训练的过程。在每个 epoch 内,我们遍历数据集中的所有 mini-batches,并对每个 mini-batch 迭代进行以下步骤:
- 将真实图片
real_images
传递给设备。 - 定义标签,其中真实图像对应的标签为 1,生成器生成的假图像对应的标签为 0。
- 训练生成器。从一个经过随机初始化的噪声向量
z
开始,使用生成器学习生成伪造图片。将生成的图片fake_images
传递给判别器,并使用判别器输出来计算生成器的损失值。这个损失值反映了生成器生成的图片与真实标签之间的差异。 - 重置生成器的优化器,并清空梯度。
- 对生成器的损失值进行反向传播计算梯度。
- 使用优化器来更新生成器的参数,使其在下一个迭代中能够更好地生成伪造图片。
- 训练判别器。使用真实图片来计算判别器对真实图片的损失值,用假图片来计算判别器对伪造图片的损失值,并将这些值相加以得到判别器的总损失。使用反向传播来计算梯度并更新判别器参数。
- 每 100 次迭代输出一次损失值,以跟踪训练的进展。
- 循环在所有训练数据上结束时,一次完整的迭代称为一个 epoch。重复执行此步骤
num_epochs
次来完成 GAN 的训练。
使用生成器生成新图像
在训练结束后,可以使用生成器来生成新图像。
# 生成一些测试数据
# 随机生成一些长度为 z_dim 的、位于设备上的向量 z
z = torch.randn(16, z_dim).to(device)
# 使用生成器从 z 中生成一些假的图片
fake_images = generator(z).detach().cpu()
# 显示生成的图像
# 创建一个图形对象,大小为 4x4 英寸
fig = plt.figure(figsize=(4, 4))
# 在图形对象中创建4x4的网格,以显示输出的16张假图像
for i in range(16):
plt.subplot(4, 4, i+1)
plt.imshow(fake_images[i][0], cmap='gray')
plt.axis('off')
# 显示绘制的图形
plt.show()
这段代码的目的是在训练 GAN 结束后,使用生成器生成一些随机的假图像进行展示。下面是具体步骤:
- 随机生成一些长度为
z_dim
的、位于设备上的向量z
。 - 使用生成器从
z
中生成一些假的图片。 - 创建一个 4x4 英寸大小的图形对象。
- 在图形对象中创建一个 4x4 的网格,用于显示生成的 16 张图片。
- 使用
plt.imshow()
函数将图像附加到当前子图中。该函数会在 matplotlib 窗口中显示生成的图像。 - 从绘图中删除坐标轴。
- 显示生成的图形。
结论
GAN 是一种强大的生成模型,可以用于生成各种不同类型的数据。本教程展示了如何使用 PyTorch 实现 GAN,并使用 MNIST 数据集进行训练。