基于MNIST的生成对抗样本(GAN)
一、导入MNIST数据集(训练模块中)
dataset=torchvision.datasets.MNIST("mnist_data",train=True,download=True)
参数:目录,是否是train模式,是否在线下载,传入transform(因为原来是utf8格式的,要转换成浮点型)
了解MNIST数据集
1、输出数据集长度
print(len(dataset))
2、查看数据集中具体数据(前五个)
for i in range(len(dataset)):
if i<5:
print(dataset[i])
else:
break
查看结果是5个image格式的对象信息
shape()函数:用于计算数组的行数列数。torchvision.datasets.MNIST:产生的image对象前面是数据后面是标签,所以dataset[][]为二维。产生报错信息:image 对象没有shape模式,所以需要调用一下数据集中的transform
添加如下代码:
transform=torchvision.transforms.Compose(
[torchvision.transforms.Resize(28),
torchvision.transforms.ToTensor(),
])
再次运行代码,会得到数据的具体参数:12828
二、生成器的大概框架
class Generator(nn.Module):
def __init__(self,in_dim):
super(Generator,self).__init__()
self.model=nn.Sequential(
nn.Linear(in_dim, 64),
torch.nn.ReLU(inplace=True),
nn.Linear(64, 128),
torch.nn.ReLU(inplace=True),
nn.Linear(128, 256),
torch.nn.ReLU(inplace=True),
nn.Linear(256, 512),
torch.nn.ReLU(inplace=True),
nn.Linear(512, 1024),
torch.nn.ReLU(inplace=True),
nn.Linear(1024,torch.prod(image_size,dtype=torch.int32)),
nn.Tanh(),
)
def forward(self,z):
output=self.model(z)
image=output.reshape(z.shape[0],*image_size)
return image
生成器框架分析:
- 输入高斯噪声z,输出生成的图像
- 整体由两部分函数构成,一部分构建网络,一部分进行连接使用
三、判别器的整体框架
class Discriminator(nn.Module):
#输入一张照片输出概率
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(np.prod(image_size, dtype=torch.int32), 1024),
torch.nn.ReLU(inplace=True),
nn.Linear(1024, 512),
torch.nn.ReLU(inplace=True),
nn.Linear(512, 256),
torch.nn.ReLU(inplace=True),
nn.Linear(256, 128),
torch.nn.ReLU(inplace=True),
nn.Linear(128, 1),
nn.Sigmoid(),
)
def forward(self, image):
prob = self.model(image.reshape(image.shape[0],-1))
return prob
判别器框架分析:
- 输入生成的图像,输出为判别图像概率
- 整体由两部分函数构成,一部分构建网络,一部分进行连接使用
四、数据训练(全部代码)
'''
实验:基于MNIST(手写数字识别)实现生成对抗网络(GAN)
'''
import torch
import torch.nn as nn
#生成器类的实现
import torchvision.datasets
from torch.utils.data import DataLoader
import numpy as np
image_size=[1,28,28]#定义常量(图片大小)
num_epoch=100
latent_dim=32
batch_size=32
class Generator(nn.Module):
#输入高斯噪声,输出图像
# in_dim: 高斯噪声z的输入维度
def __init__(self):#进行一些模块的定义
#对父类实例化,集成nn.Module
super(Generator,self).__init__()
#定义一个model,这个model可以用很多DNN去做,我们使用一个容器nn.Sequential()将他们装起来
#容器中需要传入一个个的model的,我们不需要用列表,只需要传入
self.model = nn.Sequential(
nn.Linear(latent_dim, 64),
torch.nn.ReLU(inplace=True),
nn.Linear(64, 128),
torch.nn.ReLU(inplace=True),
nn.Linear(128, 256),
torch.nn.ReLU(inplace=True),
nn.Linear(256, 512),
torch.nn.ReLU(inplace=True),
nn.Linear(512, 1024),
torch.nn.ReLU(inplace=True),
nn.Linear(1024, np.prod(image_size, dtype=np.int32)),
nn.Tanh(),
)
#由于生成器的输入是一个高斯噪声z(文中随机的高斯变量)
def forward(self,z):#将定义函数进行连接
#z的形状:[batchsize,1*28*28],定义为两维,一般将1*28*28定义为任意维度latent_dim
#将z传到model中
output=self.model(z)
#将矩阵映射为图像
# 参数:z的维度batchsize,图像大小即位image_size
image=output.reshape(z.shape[0],*image_size)#由于image_size为列表,所以添加星号表示为将列表元素独立出来分别传入
#得到生成器的输出
return image
#判别器类的实现
class Discriminator(nn.Module):
#输入一张照片输出概率
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(np.prod(image_size, dtype=np.int32), 1024),
torch.nn.ReLU(inplace=True),
nn.Linear(1024, 512),
torch.nn.ReLU(inplace=True),
nn.Linear(512, 256),
torch.nn.ReLU(inplace=True),
nn.Linear(256, 128),
torch.nn.ReLU(inplace=True),
nn.Linear(128, 1),
nn.Sigmoid(),
)
def forward(self, image): # 将定义函数进行连接
# image图像的形状:[batchsize,1,28,28]
# 将图像传到model中
prob = self.model(image.reshape(image.shape[0],-1))
# 得到生成器的输出,返回概率
return prob
#训练Training
#对MNIST数据集API进行导入(直接下载)
#参数介绍:目录,是否是train模式,是否在线下载,还要传入transform(因为原来是utf8格式的,要转换成浮点型)
dataset=torchvision.datasets.MNIST("mnist_data",train=True,download=True,
transform=torchvision.transforms.Compose(#可以包含很多操作
[
torchvision.transforms.Resize(28),#调整图像大小
torchvision.transforms.ToTensor(),#将utf-8转换为浮点型
# torchvision.transforms.Normalize(mean=[0.5],std=[0.5])#归一化,需要提前计算均值和方差
]))
#将导入的数据引入downloder中,需要传入参数dataset,batch_size(32或64均可)和shuffle
#dataloader的作用就是把dataset中的每个样本构成一个mini_barch,后面进行批训练
dataloader=DataLoader(dataset,batch_size=batch_size,shuffle=True)
#构建优化器2个,分别对生成器和判别起的参数进行优化
#实例化一个generator和discriminator
generator = Generator()
discriminator = Discriminator()
# 参数params可迭代的参数,其实是对函数调用即可得到.第二个为学习率
g_optimizer=torch.optim.Adam(generator.parameters(),lr=0.0001,betas=(0.4, 0.8), weight_decay=0.0001)
d_optimizer=torch.optim.Adam(discriminator.parameters(),lr=0.0001,betas=(0.4, 0.8), weight_decay=0.0001)
# 损失函数定义:BCE二项的交叉熵函数
loss_fn=nn.BCELoss()
for epoch in range(num_epoch):#对epoch进行循环
# 对dataloader进行遍历,使用enumerate()函数的返回参数是index(索引)和sample(一个元祖,图片和标签)
for i,mini_batch in enumerate(dataloader):
# 对mini_batch进行解析
gt_images,_=mini_batch#我们只需要mini_batch中的图片(ground_true真实图片),不需要标签
#随机生成变量z服从正态分布,形状为barch_size,z的维度1*28*28,此处定义为latent_dim
z=torch.randn(batch_size,latent_dim)#z的大小为batch_size*latent_dim
#将z喂入生成器得到预期图像
pred_images=generator(z)
#把预期图像送入判别器中进行概率预测
# discriminator(pred_images)
#对生成器进行优化
#梯度置0
g_optimizer.zero_grad()
#计算损失函数,通过图像概率和目标计算损失
# target=torch.ones(batch_size,1)
g_loss=loss_fn(discriminator(pred_images),torch.ones(batch_size,1))
#计算梯度
g_loss.backward()
#参数优化(更新)
g_optimizer.step()
#对判别器进行优化
d_optimizer.zero_grad()
#调用detach()函数对预测图像的梯度进行隔离,不需要计算它的梯度
# d_loss = 0.5*(loss_fn(discriminator(gt_images), torch.ones(batch_size,1))+loss_fn(discriminator(pred_images.detach()),torch.zeros(batch_size,1)))
real_loss = loss_fn(discriminator(gt_images), torch.ones(batch_size,1))
fales_loss =loss_fn(discriminator(pred_images.detach()),torch.zeros(batch_size,1))
d_loss=0.5*real_loss+fales_loss#当两个loss都在不断下降并相差不大,则训练成功
#计算梯度
d_loss.backward()
#参数优化(更新)
d_optimizer.step()
# #保存照片结果(每隔1000步对所生成的照片进行保存)
# if i % 1000==0:
# # 参数:存储的图片数据,存储的文件名称,
# for index,image in enumerate(pred_images):
# torchvision.utils.save_image(image,f"image_{index}.png")
#
#
# if i % 50 == 0:
# print(f"step:{len(dataloader)*epoch+i}, g_loss:{g_loss.item()}, d_loss:{d_loss.item()}, real_loss:{real_loss.item()}, fake_loss:{fales_loss.item()}")
#
五、存在问题
仅学习了解了GAN的整体框架,没有对生成器和判别器使用的网络进行设计
可以自行百度进行设计获得更好的效果