浅谈生成对抗网络GAN

简介

生成对抗网络(Generative adversarial networks)是深度学习领域的一个重要生成模型,当然还有其他的生成模型,比如VAE和其他GAN变种模型 。
为什么叫做生成对抗网络。是因为GAN的主要结构包括一个生成器G(Generator)和一个判别器D(Discriminator)。

下面以生产图片为例进行分析

生成网络

接收一个随机的噪音数据(一般服从正态分布),生成图片,记作G(Z)。Z表示噪声数据。这些数据我们可以随机生成,一般符合高斯(正态分布)

判别网络

  1. 判断真实图像的输出结果
    输入为真实数据X,输出X为真实图片的概率(0-1),记作D(X)。为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。
  2. 判断生成图像的输出结果
    输入为生成器生成的图片G(Z),输出为1或0(真或假),记作D(G(Z)

在这里插入图片描述

优化目标:

1.让生成器生成的图片G(Z)尽可能为真,骗过判别器。
2. 让判别器D提高精确度,将真实图片判为真,将生成的图片判为假

生成网络: l o g l o g ( D ( G ( z ) ) loglog(D(G(z)) loglog(D(G(z))越接近1越好,即生成的图片被判别为真实的
判别网络: l o g ( D ( x ) ) + l o g ( 1 − D ( G ( z ) ) ) log(D(x))+log(1-D(G(z))) log(D(x))+log(1D(G(z)))
loglog(D(G(z))越接近0越好,即判别器越强大能识别出生成的图片为假

上面这两个目标看似是矛盾的,这也解释了为什么叫做生成对抗网络。这样通过不断的进行多轮训练、“对抗|, 使得最后我们生成的图片可以“以假乱真”,通过判别网络判别为真。最终理想情况下, G 生成的数据与真实数据非常接近,分布也相同,而 D 无论输出真实数据还是 G 生成的数据都输出0.5。

下面以minist数据集进行训练

import torch.autograd
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image
import os


if not os.path.exists('./img'):
    os.mkdir('./img')


def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)  # Clamp函数可以将随机变化的数值限制在一个给定的区间[min, max]内:
    out = out.view(-1, 1, 28, 28)
    return out


batch_size = 128
epochs = 100  #跑100轮
z_dimension = 100 #噪声数据的维度  输入一个100维的01之间的高斯分布

# 图形预处理
transform = transforms.Compose([
    transforms.ToTensor(), # 转换PIL.Image or numpy.ndarray
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 使用 mnist数据集
mnist = datasets.MNIST(root='./data/', train=True, transform=transform, download=True)


# 数据载入
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)



#定义判别器
#这里只是用的简单的几层网络 也可以用卷积神经网络convd进行判别
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator ,self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784 ,256)  ,  # minist数据集为28*28的灰度图,所以输入特征数为28*28*1=784,输出为256
            nn.LeakyReLU(0.2)  ,  # 进行非线性映射
            nn.Linear(256 ,256)  ,  # 进行一个线性映射
            nn.LeakyReLU(0.2),
            nn.Linear(256 ,1),
            nn.Sigmoid(  )  #激活函数,二分类问题中,将实数映射到[0,1],作为概率值,
            # 多分类用softmax函数
        )

    def forward(self, x):
        x=self.dis (x)
        return x


##定义生成器 Generator
class generator(nn.Module):
    def __init__(self):
        super(generator,self).__init__()
        self.gen=nn.Sequential(
            nn.Linear(100,256),# 将一个100维的01之间的高斯分布的噪声数据映射到256,
            nn.ReLU(True),  # relu激活
            nn.Linear(256,256),# 线性变换
            nn.ReLU(True),  # relu激活
            nn.Linear(256,784),# 线性变换
            nn.Tanh()# Tanh激活  使得生成数据分布在【-1,1】之间
        )

    def forward(self, x):
        x = self.gen(x)
        return x


# 实例化
D = discriminator()
G = generator()

#如果由英伟达的GPU 就用GPU处理,生成图像更快
if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()

    ##判别器训练
    #两部分:1、真实图像
    #       2、生成的图像
    # 生成网络的参数不更新,始终用同一套参数

    # 定义损失函数loss(二分类的交叉熵)
    loss = nn.BCELoss()  # 是单目标二分类交叉熵函数
    optimizer_d = torch.optim.Adam(D.parameters(), lr=0.0003) #学习率选择0.0003
    optimizer_g = torch.optim.Adam(G.parameters(), lr=0.0003)

    #训练判别器

    for epoch in range(epochs):  # 进行100个epoch的训练
        for i, (img, _) in enumerate(dataloader):
            num_img = img.size(0)  #pytorch里特征的形式是[bs,channel,h,w],所以img.size(0)就是batchsize(每一个bath的数量)           
            img = img.view(num_img, -1)  # 将图片展开为28*28=784
            real_img = Variable(img).cuda()  # 将tensor包进Variable
            real_label = Variable(torch.ones(num_img)).cuda()  # 定义真实的图片label为1
            fake_label = Variable(torch.zeros(num_img)).cuda()  # 定义假的图片的label为0

			#判别真实图片
            real_out = D(real_img)
            d_loss_real = loss(real_out, real_label)  # 得到判别真实图片的loss
            real_scores = real_out  # 得到真实图片的判别值,输出的值越接近1越好,说明越判断正确了

            # 判别生成的假的图片
            z = Variable(torch.randn(num_img, z_dimension)).cuda()  # 先随机生成一些高斯噪声数组
            fake_img = G(z)  # 随机噪声放入生成网络中,生成一张假的图片
            fake_out = D(fake_img)  # 判别器判断假的图片
            d_loss_fake = loss(fake_out, fake_label)  # 得到假的图片的loss
            fake_scores = fake_out  # 对于判别器来说,假图片的损失越接近0越好

            # 损失函数和优化
            d_loss = d_loss_real + d_loss_fake  # 损失包括两部分 log(D(x))+log(1-D(G(z)))
            optimizer_d.zero_grad()  # 在反向传播之前,先将梯度归0
            d_loss.backward()  # 将误差反向传播
            optimizer_d.step()  # 更新参数


            #训练生成网络
            # 将假的图片传入判别器的结果与真实的label对应, 反向传播更新的参数是生成网络里面的参数,这样更新生成网络里面的参数来训练网络,使得生成的图片让判别器判为真的
         

            # 计算假的图片的损失

            z = Variable(torch.randn(num_img, z_dimension)).cuda()  # 得到随机噪声
            fake_img = G(z)  # 随机噪声丢入生成器中,生成假的图片
            output = D(fake_img)
            g_loss = loss(output, real_label)  # 假的图片与真实的图片的label的loss

            # 反向传播和参数更新
            optimizer_g.zero_grad()  # 梯度归0
            g_loss.backward()  
            optimizer_g.step()
            # 打印中间的损失
            if (i + 1) % 100 == 0:
                print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '
                      'D real: {:.6f},D fake: {:.6f}'.format(
                    epoch, epochs, d_loss.data[0], g_loss.data[0],
                    real_scores.data.mean(), fake_scores.data.mean()  # 打印的是真实图片的损失均值
                ))

            if epoch == 0:
                real_images = to_img(real_img.cpu().data)
                save_image(real_images, './img/real_images.png')

            fake_images = to_img(real_img.cpu().data)
            save_image(fake_images, './img/fake_images_{}.png'.format(epoch + 1))
    # 保存模型
    torch.save(G.state_dict(), './generator.pth')
    torch.save(D.state_dict(), './discriminator.pth')



  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值