GAN网络

定义

GAN(Generative Adversarial Nets),生成对抗网络。顾名思义,它是一个生成模型,其次,它的训练是通过两个模型之间的对抗完成的。
在这里插入图片描述

网络

GAN有两个网络,一个是生成器(generator)简称为G,一个是判别器(discriminator)简称为D。两个网络从0开始博弈,通过对抗来达到一个好的效果。
举个栗子,G是假画贩子,通过模仿真画去作假画,因为刚开始技术还不太成熟,很容易被鉴画师(D)识别出来是假的,但假画贩子坚持不懈,一直学习,使他的画越来越像真画,可这还是会被鉴画师识别出来,因为假画贩子在学习如何模仿真画的时候,鉴画师同时也在学习鉴别,所以两者是相互对抗的,直至最后假画贩子的画几乎让鉴画师都识别不出是真是假。
如下图所示,从V1-V3,所展示的图会越来越接近真实的图。
在这里插入图片描述

GAN的目标函数及流程

在这里插入图片描述
max部分即是D要尽可能识别真实数据和G生成假的数据。
min部分是G要尽可能缩小生成数据和真实数据之间的误差。
根据max部分的目标更新D的参数,提高D的分辨能力。
根据min部分的目标更新G的参数,使G生成的数据更像真实数据。

代码部分(pytorch_mnist)

import torch
import torchvision
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import numpy as np

import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

还是先将需要的包导入,再设置我们是用GPU来跑整个模型。

batch_size = 32
transform= transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
mnist = torchvision.datasets.MNIST('./mnist',train=True,download=False,transform=transform)
dataloader = torch.utils.data.DataLoader(dataset=mnist,
                                        batch_size = batch_size,
                                        shuffle=True)

这就是数据预处理部分,我们用mnist数据集,并设置一个batch是32张图片。这里要注意normalize,mnist是灰度图,只有一个channel,所以它的均值和方差只能是一位,不像RGB是三位

image_size = 28*28 #图像大小
hidden_size = 256 
latent_size = 64 

#定义判别器
D = nn.Sequential(
    nn.Linear(image_size,hidden_size), #输入特征数28*28,输出为256
    nn.LeakyReLU(0.2),                 #进行激活函数非线性映射
    nn.Linear(hidden_size,hidden_size),#进行一个线性映射
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size,1),
    nn.Sigmoid()         #二分类问题 鉴别器只用判断是真是假
)
#定义生成器
G = nn.Sequential(
    nn.Linear(latent_size,hidden_size), #输出latent_size 64维的高斯分布,通过线性变换成256
    nn.ReLU(), #激活函数
    nn.Linear(hidden_size,hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size,image_size),
    nn.Tanh() #希望生成假的图片在-1—1之间
)

D = D.to(device)
G = G.to(device)

loss_fn =nn.BCELoss()

d_opt = torch.optim.Adam(D.parameters(),lr = 0.0002)
g_opt = torch.optim.Adam(G.parameters(),lr = 0.0002)

在这里将判别器(D)和生成器(G)两个网络模型、优化器和损失函数定义好。

total_steps = len(dataloader)
num_epochs = 200
for epoch in range(num_epochs):
    for i, (image, _) in enumerate(dataloader):
        batch_size = image.size(0)
        image = image.reshape(batch_size, image_size).to(device)  # batch_size,28*28

        real_labels = torch.ones(batch_size, 1).to(device)  # 定义真实label为1
        fake_labels = torch.zeros(batch_size, 1).to(device)  # 定义假的label为0

        real_outputs = D(image)  # 真实图片放入D
        d_loss_real = loss_fn(real_outputs, real_labels)  # 希望D能输出真实的 (0——1)
        real_score = real_outputs

        # 生成fake
        z = torch.randn(batch_size, latent_size).to(device)  # latent
        fake_images = G(z)  # 防止生成网络生成假的图片
        fake_outputs = D(fake_images.detach())  # 不需要G的梯度相关计算

        d_loss_fake = loss_fn(fake_outputs, fake_labels)
        fake_score = fake_outputs

        # 开始优化D
        d_loss = d_loss_real + d_loss_fake #判别器的损失包括判断真的和判断假的
        d_opt.zero_grad() #梯度清零
        d_loss.backward()  #反向传播
        d_opt.step()#更新参数

        # 开始优化G
        z = torch.randn(batch_size, latent_size).to(device) #得到随机噪声
        fake_images = G(z) #将随机噪声输入到生成器,得到假的图片
        outputs = D(fake_images) #经过D判断得到成果
        g_loss = loss_fn(outputs, real_labels) #计算损失

        g_opt.zero_grad() #梯度清零
        g_loss.backward() #反向传播
        g_opt.step()      #更新梯度

        if i % 200 == 0:
            print("[Epochs %d/%d] [batch %d/%d] [D loss:%f] [G loss:%f]"
                  %(epoch, num_epochs, i, total_steps, d_loss.item(), g_loss.item()))

这里我训练的epoch是200个,可能还是稍微要那么一丢丢时间

z = torch.randn(batch_size, latent_size).to(device)
fake_images = G(z)
fake_images = fake_images.view(batch_size,28,28).data.cpu().numpy()
plt.imshow(fake_images[10],cmap = plt.cm.gray)

训练完成后就将图片show出来看看,如下图所示,感觉还是有点不咋地,效果有点问题,自己也是才接触GAN的小白,如有什么问题,望指正。
在这里插入图片描述

参考
https://zhuanlan.zhihu.com/p/72279816
https://blog.csdn.net/u014453898/article/details/95044228

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值