'''
GAN基础模型搭建
利用faces动漫头像数据集进行基础GAN程序编写
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import tqdm
# 一、加载训练数据集
# 数据归一化(-1,1),因为生成器的激活函数是tanh,数据也是(-1,1)
transform = transforms.Compose([
# 将图片转换为96*96的尺寸
transforms.Resize(96),
# 将图片中心截取为96*96的尺寸
transforms.CenterCrop(96),
# 将shape为(H,W,C)的数组或者.img转化为shape为(C,H,W)的tensor
# 将数据转化为张量,并尽心(0,1)归一化,并且CHW的格式
transforms.ToTensor(),
# 将数据从(0,1)之间转化到(-1,1)之间,归一化
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
'''
transforms.Compose(): 将多个预处理依次累加在一起,每次执行transform都会依次执行其中包含的多个预处理程序
transforms.Resize():调整图片的尺寸
transforms.CenterCrop():以图片中心为基准,按规定大小截取图片
transforms.ToTensor():在做数据归一化之前必须要把PIL Image转成Tensor
transforms.Normalize(0.5, 0.5):这里的两个0.5分别表示对张量进行归一化的 全局平均值和方差,
因为图像是灰色的只有一个通道,所以分别指定一了一个值,
如果有多个通道,需要有多个数字,
如3个通道,就应该是Normalize([m1, m2, m3], [n1, n2, n3])
'''
# 加载数据集faces动漫头像
train_data_dir = '../data/faces' # 图像数据储存路径
train_ds = ImageFolder(
train_data_dir,
transform=transform
)
test_ds = ImageFolder(
train_data_dir,
transform=transform
)
'''
ImageFolder:设置图片的所在文件夹和图片的调整函数
'''
train_data_loader = data.DataLoader(
dataset=train_ds,
batch_size=16,
shuffle=True,
num_workers=0
)
test_loader = data.DataLoader(
dataset=test_ds,
batch_size=16,
shuffle=True
)
'''
PyTorch中数据读取的一个重要接口是 torch.utils.data.DataLoader。
该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch_size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
torch.utils.data.DataLoader(object)的可用参数如下:
dataset(Dataset): 数据读取接口,该输出是torch.utils.data.Dataset类的对象(或者继承自该类的自定义类的对象)。
batch_size (int, optional): 批训练数据量的大小,根据具体情况设置即可。一般为2的N次方(默认:1)
shuffle (bool, optional):是否打乱数据,一般在训练数据中会采用。(默认:False)
'''
# 二、搭建生成器模型
# 输入是长度为100的噪声(正态分布)
# 由于该数据集的CHW是(3,96,96),所以我们的输出也应该为(3,96,96)
class NetG(nn.Module):
"""
生成器定义
"""
def __init__(self):
super(NetG, self).__init__()
# 生成器feature map数
ngf = 64
self.main = nn.Sequential(
# 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
nn.ConvTranspose2d(100, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# 上一步的输出形状:(ngf*8) x 4 x 4
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# 上一步的输出形状: (ngf*4) x 8 x 8
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# 上一步的输出形状: (ngf*2) x 16 x 16
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# 上一步的输出形状:(ngf) x 32 x 32
nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
nn.Tanh() # 输出范围 -1~1 故而采用Tanh
# 输出形状:3 x 96 x 96
)
def forward(self, input):
return self.main(input)
'''
生成器结构:
输出是长度为100的噪声(正态分布随机数)
输出为(3,96,96)的图片
'''
# 三、搭建判别器模型
# 输入为(3,96,96)的图片,输出为二分类的概率值,输出使用sigmoid函数激活
# BCEloss计算交叉熵
class NetD(nn.Module):
"""
判别器定义
"""
def __init__(self):
super(NetD, self).__init__()
# 判别器feature map数
ndf = 64
self.main = nn.Sequential(
# 输入 3 x 96 x 96
nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
# 判别器一般使用leakyrelu函数激活
# nn.LeakyReLU f(x):如果X>0输出0,如果X<0输出a*x,a表示一个很小的斜率,比如0.1
nn.LeakyReLU(0.2, inplace=True),
# 输出 (ndf) x 32 x 32
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# 输出 (ndf*2) x 16 x 16
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# 输出 (ndf*4) x 8 x 8
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# 输出 (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
# 判别器最后输出一般使用sigmoid的激活0-1
nn.Sigmoid()
# 输出一个数(概率)
)
def forward(self, input):
return self.main(input).view(-1)
'''
输入为为(3,96,96)的图片
输出为二分类(0/1)的概率值,
'''
# 四、设备的配置
# 判断设备是否可以调用GPU,如果没有GPU则使用CPU进行计算
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 初始化生成器和判别器,将其调用到响应的设备上
gen = NetG().to(device)
dis = NetD().to(device)
# 分别设置生成器和辨别器的优化函数
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)
# 设置损失函数,BCELoss计算交叉熵损失
loss_func = nn.BCELoss()
# 设置绘图函数
def gen_img_plot(model, test_input):
prediction = np.squeeze(model(test_input).detach().cpu().numpy())
# print(prediction.shape)
fig = plt.figure(figsize=(4, 4))
plt.close()
for i in range(test_input.size(0)):
plt.subplot(4, 4, i + 1)
# 将输出恢复在0-1之间的数值,并进行绘图
# plt.imshow((prediction[i] + 1) / 2)
'''
由于使用matplotlib显示彩色图像需要数据的维度为[width, height, channel],就是96 * 96 * 3
而tensor的维度为 3 * 224 * 224,需要用.T对Tensor进行维度顺序颠倒
才能保证画图顺利
'''
plt.imshow(((prediction[i] + 1) / 2).T)
plt.axis('off')
# plt.show()
# 图像显示暂停一秒钟
plt.pause(1)
# 测试输入,16个长度为100的随机数
test_input = torch.randn(16, 100, 1, 1, device=device)
# 五、GAN的训练
D_loss = []
G_loss = []
# 训练循环
# Generator 训练1次,Discriminator训练5次 总迭代100次
for epoch in range(400):
# 初始化损失值
d_epoch_loss = 0
g_epoch_loss = 0
# 返回批次数
batch_size = len(train_data_loader)
# 对数据集进行迭代
for step, (x, y) in tqdm.tqdm(enumerate(train_data_loader)):
# 把数据传递到设备上
x = x.to(device)
# x的第一位是size,获取批次的大小
size = x.size(0)
# 根据x的size生成对于数量且长度为100的随机数列,并传递到设备上
noise = torch.randn(size, 100, 1, 1, device=device)
# 判别器训练(包含真实图片的损失和生成图片的损失两部分),损失的构建与优化
if step % 1 == 0:
# 真实图片损失
# 梯度归零
d_optim.zero_grad()
# 判别器输入真实图片,real_output是对真实图片的预测结果
real_output = dis(x)
# 判别器对于真实图片产生的损失
d_real_loss = loss_func(
real_output,
# 期望真实图片的预测结果均为1
torch.ones_like(real_output)
)
# 计算梯度
d_real_loss.backward()
# 生成图片损失
# 使用生成器生成的图片训练判别器,期望判定为假,此时图片使用生成器计算得来的
# 得到生成器生成的图片
gen_x = gen(noise)
# 喂给判别器时要截断梯度,防止更新时把生成器也更新了
# 判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度不会再传递到生成器模型中、
fake_output = dis(gen_x.detach())
# 判别器对于生成器生成图片产生的损失
d_fake_loss = loss_func(
fake_output,
# 期望生成图片的预测结果均为0
torch.zeros_like(fake_output)
)
# 计算梯度
d_fake_loss.backward()
# 判别器总的损失
d_loss = d_fake_loss + d_real_loss
# 优化判别器
d_optim.step()
# 生成器训练
if step % 5 == 0:
# 将生成器中的梯度置零
g_optim.zero_grad()
# 判别器输入生成的图片
fake_output = dis(gen_x)
# 生成器对于判别器中输入生成器生成图片产生的损失
g_loss = loss_func(
fake_output,
# 生成器的损失,对于生成器是期望生成图片的预测结果全1
torch.ones_like(fake_output)
)
# 计算梯度
g_loss.backward()
# 优化生成器
g_optim.step()
# 累计每一个批次的loss
with torch.no_grad():
d_epoch_loss += d_loss
g_epoch_loss += g_loss
# 求每一epoch训练的平均损失
with torch.no_grad():
d_epoch_loss /= batch_size
g_epoch_loss /= batch_size
D_loss.append(d_epoch_loss)
G_loss.append(g_epoch_loss)
# 输出当前训练的epoch次数
print(f'epoch:{epoch + 1}')
# 利用绘图函数gen_img_plot绘制16个测试输入的生成器生成结果
gen_img_plot(gen, test_input)
参考