A walkthrough of the Keras example file mnist_acgan.py

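This is a GAN (an ACGAN, i.e. an auxiliary-classifier GAN, as the file name says). It is made of two neural networks: a generator and a discriminator.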

The discriminator has roughly the following structure:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 28, 28, 1)    0
__________________________________________________________________________________________________
sequential_1 (Sequential)       (None, 12544)        387840      input_1[0][0]
__________________________________________________________________________________________________
generation (Dense)              (None, 1)            12545       sequential_1[1][0]
__________________________________________________________________________________________________
auxiliary (Dense)               (None, 10)           125450      sequential_1[1][0]
==================================================================================================
Total params: 525,835
Trainable params: 525,835
Non-trainable params: 0
__________________________________________________________________________________________________

where sequential_1 has the following structure:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 14, 14, 32)        320
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 14, 14, 32)        0
_________________________________________________________________
dropout_1 (Dropout)          (None, 14, 14, 32)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 64)        18496
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 14, 14, 64)        0
_________________________________________________________________
dropout_2 (Dropout)          (None, 14, 14, 64)        0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 128)         73856
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 7, 7, 128)         0
_________________________________________________________________
dropout_3 (Dropout)          (None, 7, 7, 128)         0
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 7, 7, 256)         295168
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 7, 7, 256)         0
_________________________________________________________________
dropout_4 (Dropout)          (None, 7, 7, 256)         0
_________________________________________________________________
flatten_1 (Flatten)          (None, 12544)             0
=================================================================
Total params: 387,840
Trainable params: 387,840
Non-trainable params: 0
_________________________________________________________________

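In other words, given an image, the network applies a stack of convolution, activation, and dropout layers, then flattens the result into a 12544-dimensional vector (7 × 7 × 256). Two Dense layers branch off that vector: one decides whether the image is real (generation), and the other decides which digit it shows (auxiliary).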
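The parameter counts in the two summaries above can be checked by hand. The sketch below is plain arithmetic, not Keras code; the 3×3 kernel size is an assumption inferred from the parameter counts in the table, and the helper names are made up for illustration.

```python
def conv2d_params(kernel, in_ch, out_ch):
    """Kernel weights plus one bias per output channel."""
    return kernel * kernel * in_ch * out_ch + out_ch


def dense_params(in_dim, out_dim):
    """Weight matrix plus one bias per output unit."""
    return in_dim * out_dim + out_dim


# (in_channels, out_channels) for conv2d_1 .. conv2d_4.
# Assumption: 3x3 kernels, inferred from the counts in the summary.
channels = [(1, 32), (32, 64), (64, 128), (128, 256)]
conv_counts = [conv2d_params(3, i, o) for i, o in channels]

flat_dim = 7 * 7 * 256                    # size of the flatten_1 output
generation = dense_params(flat_dim, 1)    # real / fake head
auxiliary = dense_params(flat_dim, 10)    # digit-class head
total = sum(conv_counts) + generation + auxiliary

print(conv_counts)                  # [320, 18496, 73856, 295168]
print(flat_dim)                     # 12544
print(generation, auxiliary, total) # 12545 125450 525835
```

The numbers match the Param # column, which also confirms that LeakyReLU, Dropout, and Flatten contribute no parameters.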

The generator has roughly the following structure:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_3 (InputLayer)            (None, 1)            0
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 100)          0
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 1, 100)       1000        input_3[0][0]
__________________________________________________________________________________________________
multiply_1 (Multiply)           (None, 1, 100)       0           input_2[0][0]
                                                                 embedding_1[0][0]
__________________________________________________________________________________________________
sequential_2 (Sequential)       (None, 28, 28, 1)    2656897     multiply_1[0][0]
==================================================================================================
Total params: 2,657,897
Trainable params: 2,657,321
Non-trainable params: 576
__________________________________________________________________________________________________

where sequential_2 has the following structure:

____________________________________________________________________________________________________
Layer (type)                                 Output Shape                            Param #
====================================================================================================
dense_1 (Dense)                              (None, 3456)                            349056
____________________________________________________________________________________________________
reshape_1 (Reshape)                          (None, 3, 3, 384)                       0
____________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTranspose)         (None, 7, 7, 192)                       1843392
____________________________________________________________________________________________________
batch_normalization_1 (BatchNormalization)   (None, 7, 7, 192)                       768
____________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTranspose)         (None, 14, 14, 96)                      460896
____________________________________________________________________________________________________
batch_normalization_2 (BatchNormalization)   (None, 14, 14, 96)                      384
____________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTranspose)         (None, 28, 28, 1)                       2401
====================================================================================================
Total params: 2,656,897
Trainable params: 2,656,321
Non-trainable params: 576
____________________________________________________________________________________________________
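The spatial sizes (3 → 7 → 14 → 28) and parameter counts in the sequential_2 summary can be reproduced with a little arithmetic. This is a sketch under assumptions: the 5×5 kernels, strides, and padding modes below are inferred from the table, not read from the source file, and the helper names are hypothetical.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output length along one axis, following the Keras shape rule."""
    if padding == "same":
        return size * stride
    return size * stride + max(kernel - stride, 0)  # "valid"


def conv_transpose_params(kernel, in_ch, out_ch):
    """Kernel weights plus one bias per output channel."""
    return kernel * kernel * in_ch * out_ch + out_ch


# (out_channels, stride, padding) per Conv2DTranspose layer.
# Assumption: inferred from the output shapes in the summary.
layers = [(192, 1, "valid"), (96, 2, "same"), (1, 2, "same")]

size, in_ch = 3, 384  # shape after reshape_1
sizes, conv_counts = [], []
for out_ch, stride, padding in layers:
    size = conv_transpose_out(size, 5, stride, padding)
    sizes.append(size)
    conv_counts.append(conv_transpose_params(5, in_ch, out_ch))
    in_ch = out_ch

dense = 100 * 3456 + 3456
# BatchNormalization stores 4 values per channel: gamma and beta are
# trained, the moving mean and moving variance are not.
bn_total = 4 * (192 + 96)
bn_non_trainable = 2 * (192 + 96)
total = dense + sum(conv_counts) + bn_total

print(sizes)                    # [7, 14, 28]
print(conv_counts)              # [1843392, 460896, 2401]
print(total, bn_non_trainable)  # 2656897 576
```

The 576 non-trainable parameters are therefore the moving statistics of the two BatchNormalization layers.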

So there are two inputs: a random noise vector (input_2) and a class label (input_3), i.e. which digit to draw.

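input_3 goes through an Embedding layer and is then multiplied with input_2. This is an element-wise (Hadamard) product, not an inner product: the shape is unchanged and the result is still a 100-dimensional vector. After the Dense, Reshape, and Conv2DTranspose layers, it becomes a 28×28 grayscale image.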
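To make the difference between the element-wise product and an inner product concrete, here is a small NumPy sketch. It is not the Keras code: `embedding_table` is a random stand-in for the Embedding layer's weights, and the batch size, labels, and noise distribution are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_size, num_classes, batch = 100, 10, 4

# Stand-in for the Embedding weights: one 100-d row per digit class
# (10 * 100 = 1000 parameters, as in the summary).
embedding_table = rng.normal(size=(num_classes, latent_size))

noise = rng.uniform(-1, 1, size=(batch, latent_size))  # plays the role of input_2
labels = np.array([3, 1, 4, 1])                        # plays the role of input_3

class_vectors = embedding_table[labels]  # embedding lookup, shape (4, 100)
h = noise * class_vectors                # element-wise product, shape (4, 100)

# For contrast: an inner product would collapse each row to one number.
inner = np.sum(noise * class_vectors, axis=1)

print(h.shape, inner.shape)  # (4, 100) (4,)
```

Each noise vector is scaled coordinate by coordinate by the vector of its class, which is how the class label conditions the generated image.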

 

Stacking the generator and the discriminator gives the combined model, which has roughly this structure:

________________________________________________________________________________________________________________________
Layer (type)                           Output Shape               Param #       Connected to
========================================================================================================================
input_4 (InputLayer)                   (None, 100)                0
________________________________________________________________________________________________________________________
input_5 (InputLayer)                   (None, 1)                  0
________________________________________________________________________________________________________________________
model_2 (Model)                        (None, 28, 28, 1)          2657897       input_4[0][0]
                                                                                input_5[0][0]
________________________________________________________________________________________________________________________
model_1 (Model)                        [(None, 1), (None, 10)]    525835        model_2[1][0]
========================================================================================================================
Total params: 3,183,732
Trainable params: 2,657,321
Non-trainable params: 526,411
________________________________________________________________________________________________________________________
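The trainable and non-trainable totals of the combined model follow directly from the two earlier summaries once the discriminator is frozen. A quick arithmetic check, using only numbers copied from the tables (the variable names are made up):

```python
generator_total = 2_657_897
generator_non_trainable = 576     # BatchNormalization moving statistics
discriminator_total = 525_835     # all frozen inside the combined model

combined_total = generator_total + discriminator_total
combined_trainable = generator_total - generator_non_trainable
combined_non_trainable = discriminator_total + generator_non_trainable

print(combined_total)          # 3183732
print(combined_trainable)      # 2657321
print(combined_non_trainable)  # 526411
```

Every trainable parameter of the combined model belongs to the generator.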

 

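The example calls train_on_batch with a sample_weight argument. This sample_weight lines up with the targets [y, aux_y]: one weight array per output, with one weight per sample. Printing the lengths confirms it: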

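            # Debug prints placed inside the training loop of mnist_acgan.py.
            print(len(disc_sample_weight))     # 2: one weight array per output
            print(len(disc_sample_weight[0]))  # samples in the batch (weights for y)
            print(len(disc_sample_weight[1]))  # samples in the batch (weights for aux_y)

            tmp = [y, aux_y]
            print(len(tmp))     # 2: the two targets
            print(len(tmp[0]))  # same length as disc_sample_weight[0]
            print(len(tmp[1]))  # same length as disc_sample_weight[1]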

That is the idea. For y, the real/fake target, the loss is computed as usual; the only twist is that real images are labeled 0.95 instead of 1.

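For aux_y, computing a classification loss on freshly generated images is not meaningful, so that loss is simply multiplied by 0 for generated images. For images from the MNIST dataset, the classification loss is multiplied by 2 to compensate.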
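The targets and weights described above can be sketched in NumPy as follows. This is a minimal illustration under assumptions: the batch size is tiny, the labels are random, and `per_sample_aux_loss` holds made-up loss values purely to show the effect of the weights. The variable names y, aux_y, and disc_sample_weight follow the text.

```python
import numpy as np

batch_size, num_classes = 4, 10
rng = np.random.default_rng(0)

real_labels = rng.integers(0, num_classes, batch_size)     # digits of real MNIST images
sampled_labels = rng.integers(0, num_classes, batch_size)  # digits requested from the generator

# First half of the batch is real, second half is generated.
soft_zero, soft_one = 0.0, 0.95
y = np.array([soft_one] * batch_size + [soft_zero] * batch_size)
aux_y = np.concatenate((real_labels, sampled_labels))

disc_sample_weight = [
    np.ones(2 * batch_size),                              # y: every sample counts once
    np.concatenate((np.full(batch_size, 2.0),             # aux_y: real images count twice
                    np.zeros(batch_size))),               # aux_y: generated images are ignored
]

# Made-up per-sample classification losses, only to show the weighting.
per_sample_aux_loss = rng.uniform(0.5, 2.5, 2 * batch_size)
weighted = per_sample_aux_loss * disc_sample_weight[1]

print(y)
print(disc_sample_weight[1])
print(weighted)
```

Both weight arrays sum to 2 * batch_size, so doubling the weight of the real images keeps the overall scale of the classification loss the same as if every sample had weight 1.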

With these targets and weights, the discriminator is trained for one step.

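Next, a new batch of images is generated, their real/fake labels are all set to 0.95, and the combined model is trained for one step. Inside combined, discriminator.trainable = False, so this step updates only the generator.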
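The inputs and targets of this generator step might be built as in the sketch below. It only prepares the arrays; the call to the combined model is left as a comment because it needs the compiled Keras model. Using 2 * batch_size samples (so the generator sees as many images as the discriminator did) and uniform noise in [-1, 1] are assumptions here, as is the name `trick`.

```python
import numpy as np

batch_size, latent_size, num_classes = 4, 100, 10
soft_one = 0.95
rng = np.random.default_rng(0)

n = 2 * batch_size  # assumption: match the number of images the discriminator saw
noise = rng.uniform(-1, 1, (n, latent_size))
sampled_labels = rng.integers(0, num_classes, n)

# Every generated image is labeled "real" (0.95): the generator is
# rewarded when the frozen discriminator is fooled.
trick = np.full(n, soft_one)

generator_inputs = [noise, sampled_labels.reshape((-1, 1))]
generator_targets = [trick, sampled_labels]

# With the compiled model from the example, the step would then be:
# combined.train_on_batch(generator_inputs, generator_targets)

print(generator_inputs[0].shape, generator_inputs[1].shape)  # (8, 100) (8, 1)
print(generator_targets[0].shape, generator_targets[1].shape)  # (8,) (8,)
```

The class targets are the same labels that were fed to the generator, so the generator is also pushed to draw the digit it was asked for.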

That is essentially the whole training procedure. The rest of the code computes the test losses and saves the generated images.

As the figure below shows, the results are quite good.

 

