This is a GAN, which roughly splits into two neural networks: a generator and a discriminator.
The discriminator's structure is roughly as follows:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 28, 28, 1) 0
__________________________________________________________________________________________________
sequential_1 (Sequential) (None, 12544) 387840 input_1[0][0]
__________________________________________________________________________________________________
generation (Dense) (None, 1) 12545 sequential_1[1][0]
__________________________________________________________________________________________________
auxiliary (Dense) (None, 10) 125450 sequential_1[1][0]
==================================================================================================
Total params: 525,835
Trainable params: 525,835
Non-trainable params: 0
__________________________________________________________________________________________________
where the structure of sequential_1 is:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 14, 14, 32) 320
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 14, 14, 32) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 14, 14, 32) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 14, 14, 64) 18496
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 14, 14, 64) 0
_________________________________________________________________
dropout_2 (Dropout) (None, 14, 14, 64) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 7, 7, 128) 73856
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 7, 7, 128) 0
_________________________________________________________________
dropout_3 (Dropout) (None, 7, 7, 128) 0
_________________________________________________________________
conv2d_4 (Conv2D) (None, 7, 7, 256) 295168
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 7, 7, 256) 0
_________________________________________________________________
dropout_4 (Dropout) (None, 7, 7, 256) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 12544) 0
=================================================================
Total params: 387,840
Trainable params: 387,840
Non-trainable params: 0
_________________________________________________________________
So, given an image, a stack of convolution, activation, and dropout layers flattens it into a 12544-dimensional vector, which then feeds two Dense heads: one that judges whether the image is real (generation) and one that predicts which digit it is (auxiliary).
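As a sanity check, the parameter counts in the summaries above can be reproduced by hand; the counts imply a 3x3 kernel for every conv layer (a quick sketch, not taken from the original code):

```python
# Reproducing the discriminator's parameter counts by hand.
# The counts in the summary imply a 3x3 kernel for every conv layer.

def conv_params(k, c_in, c_out):
    """Weights plus biases of a Conv2D layer with a k x k kernel."""
    return k * k * c_in * c_out + c_out

def dense_params(d_in, d_out):
    """Weights plus biases of a Dense layer."""
    return d_in * d_out + d_out

conv_total = (conv_params(3, 1, 32)        # conv2d_1: 320
              + conv_params(3, 32, 64)     # conv2d_2: 18,496
              + conv_params(3, 64, 128)    # conv2d_3: 73,856
              + conv_params(3, 128, 256))  # conv2d_4: 295,168
print(conv_total)  # 387840, the sequential_1 total

heads = dense_params(12544, 1) + dense_params(12544, 10)
print(conv_total + heads)  # 525835, the discriminator total
```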
The generator's structure is roughly as follows:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_3 (InputLayer) (None, 1) 0
__________________________________________________________________________________________________
input_2 (InputLayer) (None, 100) 0
__________________________________________________________________________________________________
embedding_1 (Embedding) (None, 1, 100) 1000 input_3[0][0]
__________________________________________________________________________________________________
multiply_1 (Multiply) (None, 1, 100) 0 input_2[0][0]
embedding_1[0][0]
__________________________________________________________________________________________________
sequential_2 (Sequential) (None, 28, 28, 1) 2656897 multiply_1[0][0]
==================================================================================================
Total params: 2,657,897
Trainable params: 2,657,321
Non-trainable params: 576
__________________________________________________________________________________________________
where the structure of sequential_2 is:
____________________________________________________________________________________________________
Layer (type) Output Shape Param #
====================================================================================================
dense_1 (Dense) (None, 3456) 349056
____________________________________________________________________________________________________
reshape_1 (Reshape) (None, 3, 3, 384) 0
____________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTranspose) (None, 7, 7, 192) 1843392
____________________________________________________________________________________________________
batch_normalization_1 (BatchNormalization) (None, 7, 7, 192) 768
____________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTranspose) (None, 14, 14, 96) 460896
____________________________________________________________________________________________________
batch_normalization_2 (BatchNormalization) (None, 14, 14, 96) 384
____________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTranspose) (None, 28, 28, 1) 2401
====================================================================================================
Total params: 2,656,897
Trainable params: 2,656,321
Non-trainable params: 576
____________________________________________________________________________________________________
So the generator has two inputs: a noise vector (input_2) and a class label (input_3), i.e. which digit to generate.
input_3 goes through an Embedding and is then multiplied with input_2. This is an element-wise (Hadamard) product rather than an inner product, so the shape is unchanged and the result is still a 100-dimensional vector. That vector then passes through the Dense, Reshape, and Conv2DTranspose layers to produce a 28x28 grayscale image.
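The conditioning step can be illustrated with plain NumPy (a sketch: the embedding matrix here is random, whereas in the model it is learned):

```python
import numpy as np

latent_dim = 100   # size of the noise vector (input_2)
num_classes = 10   # digits 0-9 (input_3)

# An Embedding layer is just a lookup table: each class id selects
# a 100-d row. Here we use a random table for illustration.
embedding = np.random.randn(num_classes, latent_dim)

noise = np.random.randn(1, latent_dim)  # plays the role of input_2
label = np.array([3])                   # plays the role of input_3

# Element-wise (Hadamard) product: the shape stays (1, 100), so the
# class information is folded into the noise without changing size.
conditioned = noise * embedding[label]
print(conditioned.shape)  # (1, 100)
```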
Merging the generator and discriminator gives roughly the following structure:
________________________________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
========================================================================================================================
input_4 (InputLayer) (None, 100) 0
________________________________________________________________________________________________________________________
input_5 (InputLayer) (None, 1) 0
________________________________________________________________________________________________________________________
model_2 (Model) (None, 28, 28, 1) 2657897 input_4[0][0]
input_5[0][0]
________________________________________________________________________________________________________________________
model_1 (Model) [(None, 1), (None, 10)] 525835 model_2[1][0]
========================================================================================================================
Total params: 3,183,732
Trainable params: 2,657,321
Non-trainable params: 526,411
________________________________________________________________________________________________________________________
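The totals in this combined summary are consistent with the two sub-models; a quick arithmetic check (the 576 non-trainable generator parameters are the BatchNormalization moving statistics, two per channel):

```python
gen_total = 2_657_897   # model_2, the generator
disc_total = 525_835    # model_1, the discriminator (frozen here)

print(gen_total + disc_total)  # 3183732, total params

# Each BatchNormalization layer keeps 2 non-trainable values per
# channel (moving mean and variance): 2*192 + 2*96 = 576
bn_stats = 2 * 192 + 2 * 96
print(gen_total - bn_stats)    # 2657321, trainable params
print(disc_total + bn_stats)   # 526411, non-trainable params
```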
Training the discriminator uses train_on_batch with a sample_weight argument, where sample_weight lines up one-to-one with the target list [y, aux_y]:
print(len(disc_sample_weight))     # 2: one weight array per output
print(len(disc_sample_weight[0]))  # weights for y, one per sample
print(len(disc_sample_weight[1]))  # weights for aux_y, one per sample
tmp = [y, aux_y]
print(len(tmp))     # 2
print(len(tmp[0]))  # same batch size as the weights
print(len(tmp[1]))  # same batch size as the weights
Roughly: for y, i.e. whether the image is real, the loss is computed normally, with one small twist: real images get a label of 0.95 instead of 1.
For the aux_y loss, classifying freshly generated images is meaningless, so initially their classification loss is multiplied by 0; for images from the MNIST dataset, the classification loss is multiplied by 2 to compensate.
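A minimal NumPy sketch of those per-sample weights, assuming a batch whose first half is real MNIST images and second half is generated (batch_size is an assumed name, and the train_on_batch call is left as a comment since the model is not built here):

```python
import numpy as np

batch_size = 64  # assumed; first half real images, second half fake

# Targets for the "is it real" output:
# real images get the soft label 0.95 instead of 1.0
y = np.concatenate([np.full(batch_size, 0.95), np.zeros(batch_size)])

# sample_weight mirrors the target list [y, aux_y], one array per
# output: the realness loss is weighted uniformly, while the class
# loss counts real samples double (2.0) and ignores fakes (0.0).
disc_sample_weight = [
    np.ones(2 * batch_size),
    np.concatenate([np.full(batch_size, 2.0),
                    np.zeros(batch_size)]),
]
# discriminator.train_on_batch(x, [y, aux_y],
#                              sample_weight=disc_sample_weight)
```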
With those weights, we train the discriminator once.
Then we generate another batch of images, set all of their "is real" labels to 0.95, and train the combined network once. Inside that network discriminator.trainable = False, so this step updates only the generator.
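A sketch of the batch fed to the combined model for that generator step (names like `combined` follow the description above; the actual call is left as a comment since the models are not built here):

```python
import numpy as np

batch_size = 64   # assumed
latent_dim = 100  # matches input_4 in the combined summary

noise = np.random.uniform(-1, 1, (batch_size, latent_dim))
sampled_labels = np.random.randint(0, 10, (batch_size, 1))

# Every generated image is labeled "real" with the same soft label
# 0.95; the frozen discriminator then supplies gradients that push
# the generator toward images it accepts as real.
trick = np.full(batch_size, 0.95)

# combined.train_on_batch([noise, sampled_labels],
#                         [trick, sampled_labels])
```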
That is essentially the whole training procedure; the rest of the code just computes test losses and saves generated images.
As the figure below shows, the results are quite good.