20210813 -
0. 引言
最近在实现对抗自编码器的代码,想法是从最简单的模板开始。同时为了能够先找到点感觉,先看看怎么处理MNIST数据。
1. 代码示例
针对对抗自编码器的代码,找到了两份代码,分别是tensorflow实现和keras实现。其实最开始是弄的keras版本,但是判别器的判别准确率基本上一直稳定在100%,就挺奇怪的。所以,就有弄了个tensorflow来看看,不过这个问题还是没有解答。先把整理代码的过程来记录下,因为代码并不能直接跑。代码地址分别位于[1]和[2]。
1.1 Keras版本代码
这个版本的代码有一个错误,也不算错误把,属于API的版本问题。
def build_encoder(self):
# Encoder
img = Input(shape=self.img_shape)
h = Flatten()(img)
h = Dense(512)(h)
h = LeakyReLU(alpha=0.2)(h)
h = Dense(512)(h)
h = LeakyReLU(alpha=0.2)(h)
latent_repr = Dense(self.latent_dim, activation='tanh')(h)
#mu = Dense(self.latent_dim)(h)
#log_var = Dense(self.latent_dim)(h)
#latent_repr = merge([mu, log_var], mode=lambda p: p[0] + K.random_normal(K.shape(p[0])) * K.exp(p[1] / 2), output_shape=lambda p: p[0])
return Model(img, latent_repr)
他的代码部分,变量latent_repr
是由注释部分的代码来形成的,但是函数merge
在新版中已经不可用了,这段代码可以使用Lambda
层来实现。不过,在看了另外一篇文章[3]中,其指出,对于编码器的内容,可以通过3中方式来实现,最后一种不太明白,前面两种分别是决定性的(翻译是否正确有待商榷),或者类似变分自编码器的形式,将输出再链接到两个层,正是前面代码的注释部分。决定性,就是我代码中正使用的部分。
1.2 Tensorflow版本代码
首先要说明的是,tensorflow版本的代码中编码器是决定性的。但是这个代码是使用低版本tensorlofw写的(1.7.0好像是),在我2.3的环境上跑不起来。所以要代码进行一些修改。修改的部分有两个。
- example.tutorials
- 1.0api兼容性
这两个部分都可以在文章[4]中找到答案。因为使用了MNIST数据集,所以第一次运行的时候需要下载,这个下载过程,如果出现错误,可以多运行几次。
1.3 运行
通过上面的修改之后,两个版本的代码都能正常运行。不过其中他们的损失函数,我有些看不懂。其实还是训练过程中,对GAN的内在原理不是很清晰,还是需要在看看。
(未完待续,后续将记录实际的损失函数变化过程分析。。。)
引用
[1]Keras版本AAE
[2]Tensorflow版本AAE
[3]Adversarial Autoencoders
[4]Add missing example data to TensorFlow Version 2.0.0