定义
GAN(Generative Adversarial Nets),生成对抗网络。顾名思义,它是一个生成模型,其次,它的训练是通过两个模型之间的对抗完成的。
网络
GAN有两个网络,一个是生成器(generator)简称为G,一个是判别器(discriminator)简称为D。两个网络从0开始博弈,通过对抗来达到一个好的效果。
举个栗子,G是假画贩子,通过模仿真画去作假画,因为刚开始技术还不太成熟,很容易被鉴画师(D)识别出来是假的,但假画贩子坚持不懈,一直学习,使他的画越来越像真画,可这还是会被鉴画师识别出来,因为假画贩子在学习如何模仿真画的时候,鉴画师同时也在学习鉴别,所以两者是相互对抗的,直至最后假画贩子的画几乎让鉴画师都识别不出是真是假。
如下图所示,从V1-V3,所展示的图会越来越接近真实的图。
GAN的目标函数及流程
max部分即是D要尽可能识别真实数据和G生成假的数据。
min部分是G要尽可能缩小生成数据和真实数据之间的误差。
根据max部分的目标更新D的参数,提高D的分辨能力。
根据min部分的目标更新G的参数,使G生成的数据更像真实数据。
代码部分(pytorch_mnist)
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
还是先将需要的包导入,再设置我们是用GPU来跑整个模型。
batch_size = 32
transform= transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
mnist = torchvision.datasets.MNIST('./mnist',train=True,download=False,transform=transform)
dataloader = torch.utils.data.DataLoader(dataset=mnist,
batch_size = batch_size,
shuffle=True)
这就是数据预处理部分,我们用mnist数据集,并设置一个batch是32张图片。这里要注意normalize,mnist是灰度图,只有一个channel,所以它的均值和方差只能是一位,不像RGB是三位
image_size = 28*28 #图像大小
hidden_size = 256
latent_size = 64
#定义判别器
D = nn.Sequential(
nn.Linear(image_size,hidden_size), #输入特征数28*28,输出为256
nn.LeakyReLU(0.2), #进行激活函数非线性映射
nn.Linear(hidden_size,hidden_size),#进行一个线性映射
nn.LeakyReLU(0.2),
nn.Linear(hidden_size,1),
nn.Sigmoid() #二分类问题 鉴别器只用判断是真是假
)
#定义生成器
G = nn.Sequential(
nn.Linear(latent_size,hidden_size), #输出latent_size 64维的高斯分布,通过线性变换成256
nn.ReLU(), #激活函数
nn.Linear(hidden_size,hidden_size),
nn.ReLU(),
nn.Linear(hidden_size,image_size),
nn.Tanh() #希望生成假的图片在-1—1之间
)
D = D.to(device)
G = G.to(device)
loss_fn =nn.BCELoss()
d_opt = torch.optim.Adam(D.parameters(),lr = 0.0002)
g_opt = torch.optim.Adam(G.parameters(),lr = 0.0002)
在这里将判别器(D)和生成器(G)两个网络模型、优化器和损失函数定义好。
total_steps = len(dataloader)
num_epochs = 200
for epoch in range(num_epochs):
for i, (image, _) in enumerate(dataloader):
batch_size = image.size(0)
image = image.reshape(batch_size, image_size).to(device) # batch_size,28*28
real_labels = torch.ones(batch_size, 1).to(device) # 定义真实label为1
fake_labels = torch.zeros(batch_size, 1).to(device) # 定义假的label为0
real_outputs = D(image) # 真实图片放入D
d_loss_real = loss_fn(real_outputs, real_labels) # 希望D能输出真实的 (0——1)
real_score = real_outputs
# 生成fake
z = torch.randn(batch_size, latent_size).to(device) # latent
fake_images = G(z) # 防止生成网络生成假的图片
fake_outputs = D(fake_images.detach()) # 不需要G的梯度相关计算
d_loss_fake = loss_fn(fake_outputs, fake_labels)
fake_score = fake_outputs
# 开始优化D
d_loss = d_loss_real + d_loss_fake #判别器的损失包括判断真的和判断假的
d_opt.zero_grad() #梯度清零
d_loss.backward() #反向传播
d_opt.step()#更新参数
# 开始优化G
z = torch.randn(batch_size, latent_size).to(device) #得到随机噪声
fake_images = G(z) #将随机噪声输入到生成器,得到假的图片
outputs = D(fake_images) #经过D判断得到成果
g_loss = loss_fn(outputs, real_labels) #计算损失
g_opt.zero_grad() #梯度清零
g_loss.backward() #反向传播
g_opt.step() #更新梯度
if i % 200 == 0:
print("[Epochs %d/%d] [batch %d/%d] [D loss:%f] [G loss:%f]"
%(epoch, num_epochs, i, total_steps, d_loss.item(), g_loss.item()))
这里我训练的epoch是200个,可能还是稍微要那么一丢丢时间
z = torch.randn(batch_size, latent_size).to(device)
fake_images = G(z)
fake_images = fake_images.view(batch_size,28,28).data.cpu().numpy()
plt.imshow(fake_images[10],cmap = plt.cm.gray)
训练完成后就将图片show出来看看,如下图所示,感觉还是有点不咋地,效果有点问题,自己也是才接触GAN的小白,如有什么问题,望指正。
参考
https://zhuanlan.zhihu.com/p/72279816
https://blog.csdn.net/u014453898/article/details/95044228