项目源码
import torch import torchvision import torch.nn as nn import torch.optim as optim import torchvision.datasets as dsets import torchvision.transforms as transforms from torch.utils.data import DataLoader import matplotlib.pyplot as plt
# 定义生成器 class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.model = nn.Sequential( nn.Linear(100, 256), nn.ReLU(True), nn.Linear(256, 512), nn.ReLU(True), nn.Linear(512, 1024), nn.ReLU(True), nn.Linear(1024, 784), nn.Tanh() )
def forward(self, x): return self.model(x)
# 定义判别器 class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Linear(784, 512), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1), nn.Sigmoid() )
def forward(self, x): return self.model(x)
# 超参数设置 batch_size = 100 learning_rate = 0.0002 num_epochs = 5
# 准备数据集 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]) ])
mnist = dsets.MNIST(root='./data', train=True, transform=transform, download=True) data_loader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)
# 初始化生成器和判别器 G = Generator() D = Discriminator()
# 损失函数和优化器 criterion = nn.BCELoss() d_optimizer = optim.Adam(D.parameters(), lr=learning_rate) g_optimizer = optim.Adam(G.parameters(), lr=learning_rate)
# 训练GAN for epoch in range(num_epochs): for i, (images, _) in enumerate(data_loader): # 将输入展平成一维 images = images.view(batch_size, -1)
# 创建标签 real_labels = torch.ones(batch_size, 1) fake_labels = torch.zeros(batch_size, 1)
# 判别器训练:最大化 log(D(x)) + log(1 - D(G(z))) outputs = D(images) d_loss_real = criterion(outputs, real_labels) real_score = outputs
z = torch.randn(batch_size, 100) fake_images = G(z) outputs = D(fake_images) d_loss_fake = criterion(outputs, fake_labels) fake_score = outputs
d_loss = d_loss_real + d_loss_fake d_optimizer.zero_grad() d_loss.backward() d_optimizer.step()
# 生成器训练:最大化 log(D(G(z))) z = torch.randn(batch_size, 100) fake_images = G(z) outputs = D(fake_images)
g_loss = criterion(outputs, real_labels)
g_optimizer.zero_grad() g_loss.backward() g_optimizer.step()
if (i+1) % 200 == 0: print(f'Epoch [{epoch}/{num_epochs}], Step [{i+1}/{len(data_loader)}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')
# 每个epoch保存生成的图片 fake_images = fake_images.view(fake_images.size(0), 1, 28, 28) fake_images = fake_images / 2 + 0.5 # 去归一化 grid = torchvision.utils.make_grid(fake_images) plt.imshow(grid.permute(1, 2, 0).detach().cpu().numpy()) plt.show() |
学习笔记
生成器
Python class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.model = nn.Sequential( nn.Linear(100, 256), nn.ReLU(True), nn.Linear(256, 512), nn.ReLU(True), nn.Linear(512, 1024), nn.ReLU(True), nn.Linear(1024, 784), nn.Tanh() )
def forward(self, x): return self.model(x) |
生成器是GAN框架两大神经网络的其中之一,用来生成fake来迷惑判别器,以上是生成器类的定义。
class Generator(nn.Module)这段代码定义了Generator生成器类,他继承自nn.Module,这是Pytorch的一个基类,用于定义圣经网络的各个部分。
super(Generator, self).__init__()在初始化一个新的Generator对象时,确保他同样继承了父类nn.Module
self.model = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(True),
nn.Linear(256, 512),
nn.ReLU(True),
nn.Linear(512, 1024),
nn.ReLU(True),
nn.Linear(1024, 784),
nn.Tanh()
)
这便是生成器的核心,它将一系列神经网络层按顺序组合在一起。
第一层:nn.Linear(100, 256),第一层接受一个100维的随机噪声向量,这个向量通常是从标准正态分布中随机选取的(提到标准正态分布,总能想到扩散模型,这个随机噪声向量就是从这样一个向量空间中选取的),第一层接收这个100维随机向量后将其线性变化为256维的向量。
中间层:nn.Linear(256, 512),nn.Linear(512, 1024),通过逐步增加向量的维度,生成器网络逐步从噪声中提取特征,并将这些特征转换为越来越复杂的表示。
输出层:nn.Linear(1024, 784),输出层将1024维的向量转化为784层,这是因为数据集MNIST中的每张图像时28×28像素,因此需要784个像素来表示一张图像。
激活函数:nn.ReLU(True),nn.Tanh(),使用激活函数的目的在于使得本来线性变换的神经网络中引入非线性特性。nn.Tanh() 使用 nn.Tanh 作为激活函数,将输出的每个像素值映射到 [-1, 1] 的范围内。Tanh 激活函数通常用于生成图像,以便输出值适合经过标准化处理的图像数据。
def forward(self, x):
return self.model(x)
这段代码定义了前向传播。
判别器
Python class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Linear(784, 512), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1), nn.Sigmoid() )
def forward(self, x): return self.model(x) |
判别器是GAN框架两大神经网络的其中之一,用来判别生成器的输入为0或1,以上是判别器类的定义。
self.model = nn.Sequential(
nn.Linear(784, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid()
)
这便是判别器神经网络的核心,它通过几个全连接层来进行特征提取压缩,将一个784维的向量压缩到1维,到1维后,判别器才能正常输出,因为判别器的输出是一个标量。
激活函数:nn.LeakyReLU(0.2, inplace=True),nn.Sigmoid()
在每一层之后,使用 nn.LeakyReLU 激活函数,引入非线性并允许一些负值通过,以提高模型的表现力。Leaky ReLU 是 ReLU 的一种变种,它允许小部分负值通过,以防止死神经元问题。
通过 nn.Sigmoid 激活函数映射到 [0, 1] 的范围。Sigmoid 函数是常用的输出激活函数,尤其适用于二分类任务。在 GAN 中,1 表示判别器认为输入是“真实的”,0 表示判别器认为输入是“假的”。
初始化(准备工作)
Python # 超参数设置 batch_size = 100 learning_rate = 0.0002 num_epochs = 5
# 准备数据集 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]) ])
mnist = dsets.MNIST(root='./data', train=True, transform=transform, download=True) data_loader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)
# 初始化生成器和判别器 G = Generator() D = Discriminator()
# 损失函数和优化器 criterion = nn.BCELoss() d_optimizer = optim.Adam(D.parameters(), lr=learning_rate) g_optimizer = optim.Adam(G.parameters(), lr=learning_rate) |
超参数设置
batch_size = 100表示每次训练迭代中,模型处理的数据样本数量。较大的 batch_size 通常会加快训练速度,但需要更多的内存。
learning_rate = 0.0002学习率,是优化器中更新模型参数的步长大小。较小的学习率有助于更稳定的训练,而较大的学习率则可能导致模型跳过最优点。
num_epochs = 5表示模型将遍历整个训练数据集的次数。训练的 epoch 数量越多,模型通常会越好,但过多的 epoch 可能导致过拟合。(这里又涉及到过拟合和欠拟合,真是之前印象最深刻的一个概念了
数据预处理
Python transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]) ]) |
加载数据集
Python mnist = dsets.MNIST(root='./data', train=True, transform=transform, download=True) data_loader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True) |
初始化生成器和判别器
G = Generator()
D = Discriminator()
定义损失函数和优化器
criterion = nn.BCELoss()这段代码定义了损失函数,BCELoss 是二元交叉熵损失函数,适用于二分类任务。在 GAN 中,判别器输出的是一个标量值(0 表示假,1 表示真),因此使用二元交叉熵来计算损失。
d_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
g_optimizer = optim.Adam(G.parameters(), lr=learning_rate)
使用Adam优化器,它能够灵活调节学习率。
D.parameters() 和 G.parameters() 分别获取判别器和生成器的可训练参数,优化器将根据这些参数更新模型。
lr=learning_rate 指定了每次参数更新的学习率。
训练过程
Python for epoch in range(num_epochs): for i, (images, _) in enumerate(data_loader): # 将输入展平成一维 images = images.view(batch_size, -1)
# 创建标签 real_labels = torch.ones(batch_size, 1) fake_labels = torch.zeros(batch_size, 1)
# 判别器训练:最大化 log(D(x)) + log(1 - D(G(z))) outputs = D(images) d_loss_real = criterion(outputs, real_labels) real_score = outputs
z = torch.randn(batch_size, 100) fake_images = G(z) outputs = D(fake_images) d_loss_fake = criterion(outputs, fake_labels) fake_score = outputs
d_loss = d_loss_real + d_loss_fake d_optimizer.zero_grad() d_loss.backward() d_optimizer.step()
# 生成器训练:最大化 log(D(G(z))) z = torch.randn(batch_size, 100) fake_images = G(z) outputs = D(fake_images)
g_loss = criterion(outputs, real_labels)
g_optimizer.zero_grad() g_loss.backward() g_optimizer.step()
if (i+1) % 200 == 0: print(f'Epoch [{epoch}/{num_epochs}], Step [{i+1}/{len(data_loader)}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}') |
数据加载与处理
Python for epoch in range(num_epochs): for i, (images, _) in enumerate(data_loader): images = images.view(batch_size, -1) |
最外层循环,训练num_epochs个轮次
data_loader是加载好的数据集,enumerate(data_loader) 使得我们在获取每个 batch 的同时,也能得到当前 batch 的索引 i,这层循环能够以每次以batch_size的规模训练一次数据集。
images.view(batch_size, -1) 将图像展平为一维向量。MNIST 图像原本是 28x28 的二维数组,通过 .view 方法将其变为 784 维的向量。这是因为判别器的输入层需要一个一维的 784 维向量。
创建标签
Python real_labels = torch.ones(batch_size, 1) fake_labels = torch.zeros(batch_size, 1) |
真实标签:real_labels = torch.ones(batch_size, 1) 创建一个形状为 (batch_size, 1) 的全 1 向量,表示真实数据的标签。
伪造标签:fake_labels = torch.zeros(batch_size, 1) 创建一个形状为 (batch_size, 1) 的全 0 向量,表示生成数据的标签。
训练判别器
判别真实数据
Python outputs = D(images) d_loss_real = criterion(outputs, real_labels) real_score = outputs |
将images这个向量放入判别器中,返回一个标量outputs,d_loss_real为损失值,通过二元交叉熵损失函数计算。
生成伪造数据
Python z = torch.randn(batch_size, 100) fake_images = G(z) outputs = D(fake_images) d_loss_fake = criterion(outputs, fake_labels) fake_score = outputs |
z = torch.randn(batch_size, 100) 生成随机噪声向量 z,作为生成器的输入。
fake_images = G(z) 通过生成器生成假图像。
outputs = D(fake_images) 判别器对生成的假图像进行判别。
d_loss_fake 计算伪造数据的损失,使用二元交叉熵损失criterion(outputs, fake_labels)。
更新判别器
Python d_loss = d_loss_real + d_loss_fake d_optimizer.zero_grad() d_loss.backward() d_optimizer.step() |
d_loss 是判别器在当前 batch 上的总损失,它是处理真实数据的损失和处理伪造数据的损失之和。
d_optimizer.zero_grad() 清除判别器的梯度缓存。
d_loss.backward() 反向传播计算梯度。
d_optimizer.step() 更新判别器的参数。
在训练判别器的过程中,损失值有两个部分,真实数据的损失,伪造数据的损失,由这两部分一起构成损失值,共同反向传播,更新判别器。
训练生成器
Python z = torch.randn(batch_size, 100) fake_images = G(z) outputs = D(fake_images)
g_loss = criterion(outputs, real_labels) |
使用随机噪声z来生成伪造图像,判别器判断伪造图像,将返回值与伪造标签计算,得到损失值。
Python g_optimizer.zero_grad() g_loss.backward() g_optimizer.step() |
更新生成器。
训练日志输出
Python if (i+1) % 200 == 0: print(f'Epoch [{epoch}/{num_epochs}], Step [{i+1}/{len(data_loader)}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}') |
可视化
Python fake_images = fake_images.view(fake_images.size(0), 1, 28, 28) fake_images = fake_images / 2 + 0.5 # 去归一化 grid = torchvision.utils.make_grid(fake_images) plt.imshow(grid.permute(1, 2, 0).detach().cpu().numpy()) plt.show() |
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)这句代码将原本输出的batch_size*728的二维张量,转化为batch_size*1*28*28的四维张量。这样,fake_images 就被重新构造成了可以直接表示图像的形状,每个图像大小为 28x28 像素。
fake_images = fake_images / 2 + 0.5 去归一化,生成器的输出经过 tanh 激活函数,输出值范围在 [-1, 1] 之间。这一步将输出值转换回 [0, 1] 的范围(适合可视化)。具体操作是将每个像素值除以 2,再加上 0.5,得到的结果就是将 [-1, 1] 的值重新映射到 [0, 1] 的范围。
grid = torchvision.utils.make_grid(fake_images):
torchvision.utils.make_grid:这个方法将多个图像组合成一个网格(grid)图像,以便能够在一张图片中显示多个图像。常用于生成器生成图像的可视化。
例如,如果 fake_images 包含 64 张 28x28 的图像,那么 make_grid 可能会将它们排列成一个 8x8 的网格图像,也就是batch_size个图像。
plt.imshow(grid.permute(1, 2, 0).detach().cpu().numpy()):
grid.permute(1, 2, 0):permute 方法重新排列张量的维度顺序。在 PyTorch 中,图像张量的默认顺序是 (C, H, W),即通道数、图像高度、图像宽度。而 matplotlib(plt.imshow) 期望的图像顺序是 (H, W, C),即高度、宽度、通道数。permute(1, 2, 0) 将维度从 (C, H, W) 转换为 (H, W, C),以便可以正确显示图像。
detach().cpu():detach 方法用于从计算图中分离张量,这样在后续的操作中,梯度不再被跟踪。cpu() 方法将张量从 GPU 移动到 CPU(如果在 GPU 上运行的话),因为 matplotlib 只能处理在 CPU 上的张量。
numpy():numpy() 方法将张量转换为 NumPy 数组,plt.imshow 需要 NumPy 数组作为输入。
plt.show():显示图像。