生成对抗网络(二)CGAN

一、简介

        之前介绍了生成式对抗网络(GAN),关于GAN的变种比较多,我打算将几种常见的GAN做一个总结,也算是激励自己学习,分享自己的一些看法和见解。

        之前提到的GAN是最基本的模型,我们的输入是随机噪声,输出的是对应的图像,但是我们没法控制生成图像的类型。比如,我要生成一张数字0的图片,但是GAN生成的图片却是数字0-9的图片,针对这个问题,Conditional Generative Adversarial Nets被提了出来,在原有GAN的基础上,添加了类别信息以便让模型生成特定的图片。这里的条件(conditional),就是这个额外的类别信息。

二、原理

         由于在GAN的生成器和判别器中都加入了额外的类别信息,模型的目标优化函数也发生了变化。

         生成器的输入变为噪音变量P_z(z)和类别信息y, 判别器的输入为图片数据x和类别信息y, 目标函数如下:

               \underset{G}{min}\, \underset{D}{max}\, V(D,G)=E_{x\sim p_{data} (x)}[\log D(x|y)] +E_{z\sim P_{z}(z)}[\log (1-D(G(z|y)))]

         就是在GAN的目标函数上添加了y这一类别变量,x变为了条件分布。

         模型的结构图如下,

         

        GAN的结构与这个类似,生成器部分和判别器部分是分开的两个子网络,单独进行训练。类别信息y是通过embedding层嵌入的。

        具体的实现可以看看代码:

        生成器:

    def build_generator(self):

        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))

        model_input = multiply([noise, label_embedding])
        img = model(model_input)

        return Model([noise, label], img)

          标签是通过嵌入层实现的,embedding层可以将类别标签转换为对应的向量表示,在此生成器中,类别有10个(0-9),对应embedding中的input_dim, 输出维度和噪音数据是相同的,之后,再利用multiply层将两者逐项做乘积,这便是生成器的输入。

           判别器:

    def build_discriminator(self):

        model = Sequential()

        model.add(Dense(512, input_dim=np.prod(self.img_shape)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        label = Input(shape=(1,), dtype='int32')

        label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))
        flat_img = Flatten()(img)

        model_input = multiply([flat_img, label_embedding])

        validity = model(model_input)

        return Model([img, label], validity)

            判别器的输入和生成器是一样的,输出是对应的图片的类别。

            训练:

            训练采用的mnist数据集,训练时需要将图片数据和对应的标签输入模型。

            生成器和判决器作为一个整体进行训练的时候,判别器是不训练的,这时只训练生成器;当判决器作为一个单独的模型时,判决器会得到训练。二者的训练是交替进行的。

            具体的代码可以参考github

三、效果

       最后跑出来的效果还是很不错的,我在台式机上跑的,用的是1050ti的显卡,训练速度还比较快,一共20000轮,大概10分钟左右跑完。

       这是最后的训练效果:

       

       可以与前一篇博客里面的内容进行比较,与原始的GAN相比,效果要好一些,但是还是不是很清晰。一方面,mnist提供的图片像素较低,另一方面,我们采用的是全连接神经网络,对于图片的处理效果并不是很好。

       要生成更加清晰地图片,可以利用DCGAN,这也是我接下来要做的工作。

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值