生成对抗网络学习笔记

1.什么是GAN

GAN是一种深度神经网络架构,由一个生成网络与一个判别网络组成。生成网络产生假数据,并试图欺骗判别网络,而判别网络则对生成数据进行真伪鉴别,试图正确识别出所有假数据。在训练迭代过程中,两个网络持续进化与对抗,直到达到平衡,判别网络无法再识别假数据,训练结束

GAN模型主要包括两个部分:生成模型(Generative Model)与判别模型(Discriminative Model),也叫生成器(generator)与判别器(discriminator)

生成器学习的是数据的分布,从而让自身生成的数据更加真实,从而骗过判别器;判别器则是对接收的图片进行真假判别。

最终两个网络达到一个动态平滑:生成器生成的数据接近于真实数据分布,而判别器识别不出真假数据,对于给定数据的预测值为真的概率基本接近于0.5(相当于随机猜测类别)

如图,生成器生成猫的图片,判别器接收的其生成的数据以及真实数据,对二者进行判断。

生成器希望对生成样本的预测值为1,判别器希望对生成样本的结果判别为0

GAN的关键在于损失函数处理,对于判别模型:损失函数较易定义,判别器主要用于判断一张图片是真实的还是生成的,是个二分类问题。对于生成模型,损失函数并不容易,生成器希望生成接近真实的数据,对于生成的数据是否真实,难以用数学公理化定义的范式,因此,可以将生成器的输出,交给判别器,令判别器判断该数据是真是假,如此就将生成器与判别器联合了起来

2.GAN算法流程及公式

以生成图片为例,假设有两个网络G(Generator)与D(Discriminator)

G是个生成图片的网络,其接收一个随机噪声z,通过该噪声生成图片,记为G(z)

D是个判别网络,判断一个图片是否真实,其输入参数为x,x代表一个图片,输出D(x)代表x是真实图片的概率,若为1,则100%为真实图片,输出0,代表不可能是真实图片

训练流程:将随机噪声输入到生成网络G,得到生成的图片,判别器接收生成的图片与真实的图片,并尽量将二者区分,在这个过程中,是否能够正确区分生成图片与真实图片将作为判别器的损失,而能否生成近似真实的图片并使判别器将生成的图片判别为真将作为生成器的损失。

注:生成器的损失是通过判别器的输出计算的,而判别器输出的是概率值,可通过交叉熵损失函数计算

GAN的公式:

其中x~data代表从真实图片中抽取的一张图片,z为随机分布的噪声,G(z)代表根据噪声生成的图片,D(G(z))则是对生成的图片进行判别。

公式的前半部分D(x)代表判别器对于真实图片的判断,希望其判断为1,则D(x)结果越接近1,则log(D(x))计算的损失值越小,相反,对其如果判断为0,则会约小(为负,log0=负无穷),从而放大损失

公式的后半部分D(G(z))代表判别器对从噪声中生成图片真伪的判别结果,对于判别器,若输入是生成的数据,则希望对其判断结果为0,希望D(G(Z))的结果越小越好,而为了统一两个模型都是越大越好,因此用1-D(G(z)),代表越大越好;而对于生成器,他希望D(G(z))的结果接近于1,越大越好,所以希望(1-D(G(z)))越小越好

使用Log是为了放大损失

综上,从判别器的角度希望最大化V(D,G),从生成器角度,希望最小化V(D,G)

3.基础GAN代码实现

我希望通过前10个数据预测下一个数据,因此一条数据长度为11,数据集是个n*11的矩阵

步骤1:将数据集归一化至(-1, 1)

    # 1.归一化到[-1, 1]之间
    min_max_scaler = preprocessing.MinMaxScaler(feature_range=(-1, 1))
    new_data = min_max_scaler.fit_transform(data)

步骤二:实现生成器代码

"""将噪声序列转化为目标序列"""
import torch
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self, in_size, out_size):
        # 噪声序列有inpuut_size个数据,要生成的是out_size长度的序列
        super(Generator, self).__init__()
        self.linear = nn.Sequential(
            nn.Linear(in_size, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, out_size),
            nn.Tanh()
        )

    def forward(self, x):
        """输入的x是[batch_size, 1, in_size]"""
        out = self.linear(x)
        return out

 输入是个长度为in_size的噪声序列,我们需要根据这个序列输出长度为out_size的序列,这个out_size就是前面数据集所需要的一条数据的长度(11)

这个模型此处结构简单,3个linear,需注意的是最后一个激活函数最好使用Tanh,而不是ReLu

步骤三:实现判别器代码

判别器实际上是个二分类模型,对于我们的任务,就是判断一个序列是真实数据还是生成数据

在该模型中,使用sigmoid函数进行激活,且每一层之间的激活函数用LeakyReLu,其作用是当值<0时,输出a*x,a是个很小的斜率值

注:损失计算采用BECloss计算交叉熵损失

"""输入的是个序列(1, len),输出为其二分类的概率值,使用sigmoid激活0-1
用BECloss计算交叉熵损失,
激活函数LeakyReLu:>0 输出x, <0输出 a*x, a代表一个小斜率值"""
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    def __init__(self, in_size):
        """in_size是输入序列的长度"""
        super(Discriminator, self).__init__()
        self.linear = nn.Sequential(
            nn.Linear(in_size, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        out = self.linear(x)
        return out


if __name__ == '__main__':
    x = torch.randn(size=(20, 11))
    model = Discriminator(in_size=11)
    out = model(x)
    print(out.shape)

步骤四:初始化模型、优化器、计算损失

    # 2.定义模型
    gen = Generator(in_size=10, out_size=one_data_size)  # 输入的十个长度为10的噪声,输出是一个数据的长度
    dis = Discriminator(in_size=one_data_size)  # 输入一条数据,输出其为真实数据的概率
    # 3.定义优化器
    d_optim = torch.optim.Adam(dis.parameters(), lr=0.0001)
    g_optim = torch.optim.Adam(dis.parameters(), lr=0.0001)
    loss = torch.nn.BCELoss()  # 损失

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值