一、前言
主要参考:
日月光华对GAN的讲解和演示
https://www.bilibili.com/video/BV1xm4y1X7KZ/?p=8&spm_id_from=pageDriver&vd_source=7d9525fad541a2d64d7edb7ee9f5fefa
实验环境:
pytorch 1.10.2+pycharm
二、代码
#自己写一遍gan网络
#用pytorch实现,数据集使用手写数字集
import torch
import torch.nn as nn#神经网络工具箱
#torchvision
from torchvision import transforms#数据预处理
from torchvision import datasets#mnist数据集
import numpy as np#
import matplotlib.pyplot as plt#绘图包
#确定超参数
batch_size=64
epochs=100
#1.首先加载数据集
#数据集预处理
#对数据集做归一化(-1,1)
img_transforms=transforms.Compose([
transforms.ToTensor(),#转换成tensor格式
transforms.Normalize(0.5,0.5)
])
#下载数据集
mnist=datasets.MNIST(root='./data/',train=True,transform=img_transforms,download=False)
#数据加载
dataLoader=torch.utils.data.DataLoader(mnist,batch_size=batch_size,shuffle=True)
#2.定义生成器
#搞明白输入和输出分布是什么,输入是100维的噪声(服从正态分布),输出是(1,28,28)的图片
class Generator(nn.Module):#继承
def __init__(self):#构造器
super(Generator,self).__init__()#super()调用父类
#生成器的网络
self.net=nn.Sequential(#展平
nn.Linear(100,256),
nn.ReLU(),
nn.Linear(256,512),
nn.ReLU(),
nn.Linear(512,28*28),
nn.Tanh()#输出范围[-1,1]
)
def forward(self,z):#z为长度一百的噪声输入
img=self.net(z)
img=img.view(-1,28,28,1)#(28*28)->(1,28,28)
return img
#3.定义判别器
#输入为(1,28,28)的图片,输出是二分类的概率值,输出使用sigmoid激活
#判别器一般使用LeakyReLu激活函数,带一点斜率,初始为0.2
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
#判别器网络
self.net=nn.Sequential(
nn.Linear(28*28,512),
nn.LeakyReLU(),
nn.Linear(512,256),
nn.LeakyReLU(),
nn.Linear(256,1),
nn.Sigmoid()#输出在[0,1]之间
)
def forward(self,img):#img为输入判别器的图片(1.28,28)
x=img.view(-1,28*28)
x=self.net(x)
return x
#4.初始化模型
# 定义运行设备
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#在设备上运行
gen=Generator().to(device)
dis=Discriminator().to(device)
#定义优化器
d_optim=torch.optim.Adam(dis.parameters(), lr=0.0001)
g_optim=torch.optim.Adam(gen.parameters(),lr=0.0001)
#用BCELoss计算交叉熵损失(二分类)
loss_fn=nn.BCELoss()
#5.绘图函数
# 查看生成的图片
def gen_img_plot(model,test_input):## 将噪声放进去生成器中
prediction=np.squeeze(model(test_input).detach().cpu().numpy())#保留梯度
fig = plt.figure(figsize=(4,4))#16张图片
for i in range(16):
plt.subplot(4,4,i+1)
plt.imshow((prediction[i]+1)/2)
plt.axis('off')
plt.show()
test_input= torch.randn(batch_size,100,device=device)
#6.gan的训练
D_loss=[]
G_loss=[]
#循环多少个epoch
for epoch in range(epochs):
d_epoch_loss=0
g_epoch_loss = 0
count=len(dataLoader)#总样本个数(一个epoch的个数)
for step,(img,_) in enumerate(dataLoader):#enumerate对数据编号,img,_只取图片,不取标签,img的格式为(64,1,28,28)
img=img.to(device)
size=img.size(0)
#在真实图片上计算判别器损失
d_optim.zero_grad()
real_output=dis(img)
d_real_loss=loss_fn(real_output,torch.ones_like(real_output))#在真实图片上的输出尽可能接近1
d_real_loss.backward()
#在生成图片上计算判别器损失
random_noise=torch.randn(size,100,device=device)
gen_img=gen(random_noise)
fake_output=dis(gen_img.detach())#这里只更新判别器,所以要截断梯度
d_fake_loss=loss_fn(fake_output,torch.zeros_like(fake_output))
d_fake_loss.backward()
#判别器的总和损失
d_loss=d_fake_loss+d_real_loss
d_optim.step()#优化
#计算生成器损失
g_optim.zero_grad()
fake_output=dis(gen_img)
g_loss=loss_fn(fake_output,torch.ones_like(fake_output))
g_loss.backward()
g_optim.step()
#损失绘图
with torch.no_grad():#不更新梯度
d_epoch_loss+=d_loss#计算一个epoch的判别器总和损失
g_epoch_loss+=g_loss
with torch.no_grad():
d_epoch_loss/=count#计算一个epoch的判别器平均损失
g_epoch_loss/=count
D_loss.append(d_epoch_loss.item())#D_loss列表里面存放所有的损失值
G_loss.append(g_epoch_loss.item())
print('epoch:',epoch)
gen_img_plot(gen,test_input)