你了解变分自编码器吗? 请看这里

版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_40652148/article/details/80467662
10.9  变分自编码器
前面所描述的自编码器可以降维重构样本,在这基础上我们来学习一个更强大的自编码器。
10.9.1  什么是变分自编码器
变分自编码器学习的不再是样本的个体,而是要学习样本的规律。这样训练出来的自编码器不单单具有重构样本的功能,还具有了仿照样本的功能。
听起来这么强大的功能,到底是怎么做到的?下面我们来讲讲它的原理。
变分自编码器,其实就是在编码过程中改变了样本的分布(“变分”可以理解为改变分布)。前文中所说的“学习样本的规律”,具体指的就是样本的分布,假设我们知道样本的分布函数,就可以从这个函数中随便取一个样本,然后进行网络解码层前向传导,就可以生成了一个新的样本。
为了得到这个样本的分布函数,我们的模型训练目的将不再是样本本身,而是通过加一个约束项,将我们的网络生成一个服从于高斯分布的数据集,这样按照高斯分布里的均值和方差规则可以任意取相关的数据,然后通过解码层还原成样本。
10.9.2  实例82:使用变分自解码器模拟生成MNIST数据
对于变分自解码器,好多文献都是给出了一堆晦涩难懂的公式,其实里面真正的公式只有一个——KL离散度的计算。而它也属于成熟的式子,就跟交叉熵一样,直接拿来用就可以。
公式本来是语言的高度概括,而一篇文章全是公式没有语言就会令人难以理解。本文只会有代码加上语言描述,不会让这部分知识读起来感觉晦涩。
代码例子共分如下几个步骤,下面我们就来一一操作。
案例描述
使用变分自编码模型进行模拟MNIST数据的生成。
1.引入库,定义占位符

这次建立的网络与以前略有不同,编码为两个全连接层由784到256再到两个2层的并列输出,然后将两个输出通过一个公式的计算,输入到以一个2节点为开始的解码部分,接着2个全连接层又2到256再到784。如图10-17


                                             图10-17 变分解码器层次
具体的计算公式,后文会有详细介绍。
在下面的代码中与前面代码不同,下面引入了一个scipy库,在后面可视化时会用到。头文件引入之后,定义操作符x和z。x用于原始的图片输入,z用于中间节点解码器的输入。

代码10-8  变分自编码器


zinput是个占位符,在后面要通过它将分布数据输入,用来生成模拟样本数据。
2.定义学习参数
由于这次的网络结构不同,所以定义的参数也有变化,mean_w1与mean_b1是生成mean的权重,log_sigma_w1与log_sigma_b1是生成log_sigma的权重。

代码10-8  变分自编码器(续)




3.定义网络结构
按照上面图10-16的描述,网络节点可以按照以下代码来定义,在变分解码器为训练的中间节点赋予了特殊的意义,让它们代表均值和方差,并将他们所代表的数据集向着标准高斯分布数据集靠近(也就是原始数据是样本,高斯分布数据是标签),然后可以使用kl散度公式,来计算它所代表的集合与标准的高斯分布集合(均值是0,方差为1的正态分布)间的距离,将这个距离当成误差让它最小化从而来优化网络参数。
这里的方差节点不是真正意义的方差,是取了log之后的。所以会有tf.exp(z_log_sigma_sq)的变换,是取得方差的值,再tf.sqrt将其开平方得到标准差。用符合标准正太分布的一个数来乘上标准差加上均值,就使这个数成为符合(z_mean,sigma)数据分布集合里面的一个点(z_mean是指网络生成均值,sigma是指网络生成的z_log_sigma_sq变换后的值)。


到此,完成了编码阶段。将原始数据编码输出3个值:
● 一个是该表述数据分布的均值,
● 一个是表述该数据分布的方差,
● 还有一个是得到了该数据分布中的一个实际的点z。


代码10-8  变分自编码器(续)


得到了符合原数据集上的一个具体点z之后,就可以通过神经网络这个点z还原成原始数据reconstruction了。这个解码部分还是和以前的内容一样,参照编码的网络逐层还原回去。
h2out和reconstructionout两个节点不属于训练中的结构,是为了生成指定数据时用的。
4.构建模型的反向传播
和以往一样,需要定义损失函数的节点和优化算法的op,代码如下。

代码10-8  变分自编码器(续)


上面代码描述了网络两个优化方向:
● 一个是比较生成的数据分布与标准高斯分布的距离,这里使用KL离散度的公式(见latent_loss)。
● 另一个是计算生成数据与原始数据间的损失,这里用的是平方差,也可以用交叉熵。
最后将两种损失值放在一起,通过adam的随机梯度下降算法来实现在训练中的优化参数。
5.设置参数,进行训练
这步骤与前面类似,设置训练参数,迭代50次,在session中每次循环取指定批次数据进行训练。

代码10-8  变分自编码器(续)


可视化部分这里不再详述,可以参考本书的配套代码,最终程序运行的结果输出如下,结果如图10-18所示。



可以看到生成的数字,不再一味单纯的学习形状,而是通过数据分布的方式学习规则,对原有图片具有更清晰的修正功能。

仿照前面的可视化代码,将均值和方差代表的二维数据在直角坐标系中展现如下:


 
                                                           图10-19变分自解码二维可视化
从图10-19中可以看出,具有代表同一数值的图片的特征数据分布还是比较集中的,说明变分字节码也具有降维功能,也可以用它进行分类任务的数据降维预处理部分。
6.高斯分布取样,生成模拟数据
为了进一步证实模型学到数据分布的情况,我们这次在高斯分布中抽样去取一些点,将其映射到模型中的z,然后通过解码部分还原成真实图片看看效果,代码如下。

注意:


代码10-8  变分自编码器(续)


运行以上代码生成如图10-20所示图片。


                                              图10-20 变分自解码生成模拟数据

可以看到,在神经网络的世界里,所以左下角到右上角显示了网络是按照图片的形状变化而排列的,并不像我们人类一样,把数字按照1到9的排列,因为机器学的只是图片,而人类对数字的理解更多的是在于它幕后的意思。

更多章节请购买《深入学习之 TensorFlow 入门、原理与进阶实战》全本



展开阅读全文

没有更多推荐了,返回首页