ACGAN 生成自己手写数字数据集


前言

由于有可能使用GAN 网络来做一些数据增强,所以这里复现一下GAN 网络,发现这玩意儿还挺好玩。

一、GAN是什么?

GAN (Generative Adversarial Networks)生成对抗网络,用来生成一下不存在的真实数据。应用场景如下:
1.风格迁移:也就是传说中的AI 画家
2.图像超分辨率重建: 让图像更加清晰
3.生成不存在的真实数据:人脸生成等~
在这里插入图片描述
根据训练时带不带标签,GAN 网络是可分为无监督和半监督式的网络。GAN
网络分为两部分,Generator (生成器,图中G)和 Discriminator (判别器,图中D)…
随机生成的噪声,通过生成器,生成我们想要的数据,然后把这个数据和真实数据一起送入到判别器中判断,如果判别器认为输入的是生成数据,那么久训练判别器,如果判别器把生成的数据认为是真的数据,那么就要训练判别器啦~,生成器与判别器两者之间相互博弈,最后让生成器能够成功的欺骗过判别器,那么就可以使用生成器来生成想要的数据啦。

根据前人经验,生成器中的激活函数一般用relu。判别器中的激活函数一般用LeakyReLU

二、ACGAN

1.ACGAN 网络结构

由于ACCGAN 是带有标签的GAN 如果训练得当,应该可以生成想要的数据。看看它的网络结构:在这里插入图片描述

图中,输入到 生成器中的标签 C 和 Z 是随机生成的,但一般都要符合正态分布,生成器生成的假数据,将和真实数据一起输入到判别器中进行判断,真实数据的label 将和判别器输出的label 做损失计算,另一端的输出,只需要判断真假就好。

2.Generator 生成器实现

代码如下:

    def built_generator(self):
        model = Sequential()

        model.add(Dense(128 * 7 * 7, activation='relu', input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 128)))
        model.add(BatchNormalization(momentum=0.8))

        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=0.8))

        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=0.8))

        # model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Conv2D(self.channels, kernel_size=3, padding='same', activation='tanh'))

        model.summary()

        # -----------------
        # 生成噪声
        # -----------------、
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')

        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
        # print(Embedding(self.num_classes, self.latent_dim)(label).shape)
        model_input = multiply([noise, label_embedding])

        img = model(model_input)

        return Model([noise, label], img)

关于生成器中的参数设置,首先是全连接 7x7x128, 由于手写数字 图片大小为28x28,初始大小设为7x7 后续会通过2次上采样,就会变成14x14 再由14x14 变为28x28 ,还原图片的大小。

注意:如果要训练自己的图片数据,记得计算好图片大小和上采样的次数,每次上采样,特征图会扩大到原来的两倍

3.Discriminator 判别器实现

    def built_discriminator(self):
        model = Sequential()

        model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model
  • 5
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

__不想写代码__

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值