'''
GAN基础模型搭建
利用MNIST手写字母数据集进行基础GAN程序编写
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
# 一、加载训练数据集
# 数据归一化(-1,1),因为生成器的激活函数是tanh,数据也是(-1,1)
transform = transforms.Compose([
# 将shape为(H,W,C)的数组或者.img转化为shape为(C,H,W)的tensor
# 将数据转化为张量,并尽心(0,1)归一化,并且CHW的格式
transforms.ToTensor(),
# 将数据从(0,1)之间转化到(-1,1)之间,归一化
transforms.Normalize(0.5, 0.5),
])
'''
transforms.Compose(): 将多个预处理依次累加在一起,每次执行transform都会依次执行其中包含的多个预处理程序
transforms.ToTensor():在做数据归一化之前必须要把PIL Image转成Tensor
transforms.Normalize(0.5, 0.5):这里的两个0.5分别表示对张量进行归一化的 全局平均值和方差,
因为图像是灰色的只有一个通道,所以分别指定一了一个值,
如果有多个通道,需要有多个数字,
如3个通道,就应该是Normalize([m1, m2, m3], [n1, n2, n3])
'''
# 加载数据集MNIST手写字母
train_ds = torchvision.datasets.MNIST(
root='../data/MNIST',
train=True,
transform=transform,
download=False
)
test_ds = torchvision.datasets.MNIST(
root='../data/MNIST',
train=False,
transform=transform,
download=False
)
'''
root :需要下载至地址的根目录位置
train:如果是True, 下载训练集 trainin.pt; 如果是False,下载测试集 test.pt; 默认是True
transform:一系列作用在PIL图片上的转换操作,返回一个转换后的版本
download:是否下载到 root指定的位置,如果指定的root位置已经存在该数据集,则不再下载
'''
train_loader = data.DataLoader(
dataset=train_ds,
batch_size=128,
shuffle=True
)
test_loader = data.DataLoader(
dataset=test_ds,
batch_size=128,
shuffle=True
)
'''
PyTorch中数据读取的一个重要接口是 torch.utils.data.DataLoader。
该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch_size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
torch.utils.data.DataLoader(object)的可用参数如下:
dataset(Dataset): 数据读取接口,该输出是torch.utils.data.Dataset类的对象(或者继承自该类的自定义类的对象)。
batch_size (int, optional): 批训练数据量的大小,根据具体情况设置即可。一般为2的N次方(默认:1)
shuffle (bool, optional):是否打乱数据,一般在训练数据中会采用。(默认:False)
'''
# 二、搭建生成器模型
# 输入是长度为100的噪声(正态分布)
# 由于该数据集的CHW是(1,28,28),所以我们的输出也应该为(1,28,28)
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.main = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 28 * 28),
# 最后一层使用tanh激活,使其分布在(-1,1)之间
nn.Tanh()
)
# 定义前向传播
def forward(self, x): # x表示长度为100的噪声输入
x = self.main(x)
# 转化成图片的形式
x = x.view(-1, 28, 28)
return x
'''
生成器结构:
输出是长度为100的噪声(正态分布随机数)
输出为(1,28,28)的图片
linear 1:100---256
linear 2: 256--512
linear 3:512--784(28*28)
reshape: 784---(1,28,28)
'''
# 三、搭建判别器模型
# 输入为(1,28,28)的图片,输出为二分类的概率值,输出使用sigmoid函数激活
# BCEloss计算交叉熵
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Linear(28 * 28, 512),
# 判别器一般使用leakyrelu函数激活
# nn.LeakyReLU f(x):如果X>0输出0,如果X<0输出a*x,a表示一个很小的斜率,比如0.1
nn.LeakyReLU(),
nn.Linear(512, 256),
nn.LeakyReLU(),
nn.Linear(256, 1),
# 判别器最后输出一般使用sigmoid的激活0-1
nn.Sigmoid()
)
# 定义前向传播
def forward(self, x): # x表示长度为28*28的图片输入
x = x.view(-1, 28 * 28)
x = self.main(x)
return x
'''
输入为为(1,28,28)的图片
输出为二分类(0/1)的概率值,
linear 1:784(28*28)--512
linear 2:512--256
linear 3:256--1
reshape: 1---(0 or 1)
'''
# 四、设备的配置
# 判断设备是否可以调用GPU,如果没有GPU则使用CPU进行计算
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 初始化生成器和判别器,将其调用到响应的设备上
gen = Generator().to(device)
dis = Discriminator().to(device)
# 分别设置生成器和辨别器的优化函数
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)
# 设置损失函数,BCELoss计算交叉熵损失
loss_func = nn.BCELoss()
# 设置绘图函数
def gen_img_plot(model, test_input):
prediction = np.squeeze(model(test_input).detach().cpu().numpy())
fig = plt.figure(figsize=(4, 4))
plt.close()
for i in range(test_input.size(0)):
plt.subplot(4, 4, i + 1)
plt.imshow((prediction[i] + 1) / 2) # 将输出恢复在0-1之间的数值,并进行绘图
plt.axis('off')
# plt.show()
# 图像显示暂停一秒钟
plt.pause(1)
# 测试输入,16个长度为100的随机数
test_input = torch.randn(16, 100, device=device)
# 五、GAN的训练
D_loss = []
G_loss = []
# 训练循环
for epoch in range(20):
# 初始化损失值
d_epoch_loss = 0
g_epoch_loss = 0
# 返回批次数
batch_size = len(train_loader)
# 对数据集进行迭代
for step, (x, y) in enumerate(train_loader):
# 把数据传递到设备上
x = x.to(device)
# x的第一位是size,获取批次的大小
size = x.size(0)
# 根据x的size生成对于数量且长度为100的随机数列,并传递到设备上
noise = torch.randn(size, 100, device=device)
# 判别器训练(包含真实图片的损失和生成图片的损失两部分),损失的构建与优化
# 真实图片损失
# 梯度归零
d_optim.zero_grad()
# 判别器输入真实图片,real_output是对真实图片的预测结果
real_output = dis(x)
# 判别器对于真实图片产生的损失
d_real_loss = loss_func(
real_output,
# 期望真实图片的预测结果均为1
torch.ones_like(real_output)
)
# 计算梯度
d_real_loss.backward()
# 生成图片损失
# 使用生成器生成的图片训练判别器,期望判定为假,此时图片使用生成器计算得来的
# 得到生成器生成的图片
gen_x = gen(noise)
# 喂给判别器时要截断梯度,防止更新时把生成器也更新了
# 判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度不会再传递到生成器模型中、
fake_output = dis(gen_x.detach())
# 判别器对于生成器生成图片产生的损失
d_fake_loss = loss_func(
fake_output,
# 期望生成图片的预测结果均为0
torch.zeros_like(fake_output)
)
# 计算梯度
d_fake_loss.backward()
# 判别器总的损失
d_loss = d_fake_loss + d_real_loss
# 优化判别器
d_optim.step()
# 生成器训练
# 将生成器中的梯度置零
g_optim.zero_grad()
# 判别器输入生成的图片
fake_output = dis(gen_x)
# 生成器对于判别器中输入生成器生成图片产生的损失
g_loss = loss_func(
fake_output,
# 生成器的损失,对于生成器是期望生成图片的预测结果全1
torch.ones_like(fake_output)
)
# 计算梯度
g_loss.backward()
# 优化生成器
g_optim.step()
# 累计每一个批次的loss
with torch.no_grad():
d_epoch_loss += d_loss
g_epoch_loss += g_loss
# 求每一epoch训练的平均损失
with torch.no_grad():
d_epoch_loss /= batch_size
g_epoch_loss /= batch_size
D_loss.append(d_epoch_loss)
G_loss.append(g_epoch_loss)
# 输出当前训练的epoch次数
print(f'epoch:{epoch + 1}')
# 利用绘图函数gen_img_plot绘制16个测试输入的生成器生成结果
gen_img_plot(gen, test_input)
参考:
GAN代码实战和原理精讲 PyTorch代码进阶 最简明易懂的GAN生成对抗网络入门课程 使用PyTorch编写GAN实例 2021.12最新课程