自编码网络(五)——变分自编码

变分自编码其实就是在编码过程中改变样本的分布,“学习样本的规律”,具体指的就是样本的分布,假设知道样本的分布函数,就可以从这些函数中随便取一个样本,然后进行网络解码层前向传导,这样就可以生成一个新的样本。

为了得到这这个样本的分布函数,模型训练的目的将不再是这样的样本本身,而是通过加一个约束项,将网络生成一个服从高斯分布的数据集,这样按照高斯分布里的均值和方差规则就可以取任意相关的数据,然后通过解码层还原成样本。

使用自编码模型进行模拟MNIST数据的生成

1.引入库,定义占位符

编码器为两个全连接层,第一个全连接层有784个维度的输入变化256个维度的输出;第二个全连接层并列连结了两个输出网络,每个网络都输出了两个维度的输出。然后将两个输出通过一个公式计算,输入到一个2节点为开始的解码部分,接着后面为两个全连接层的解码器,第一层由两个维度输入256个维度的输出,第二层有256个维度的输入到784个维度的输出。如下图所示。

在下面的代码中,引入了一个scipy库,后面的可视化会用到。头文件引入之后,定义操作符x和z.x用原始的图片输入,z用于中间节点解码器的输入。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from scipy.stats import norm

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/data/")


n_input = 784
n_hidden_1 = 256
n_hidden_2 = 2

x = tf.placeholder(tf.float32, [None,  n_input])

zinput = tf.placeholder(tf.float32, [None, n_hidden_2])

2.定义学习参数

mean_w1与mean_b1生成的是mean的权重,log_sigma_w1与log_sigma_b1生成的是log_sigma的权重。

这里初始化w的权重与之前不同,使用了很小的值(方差为0.001的truncated_nomal)。这里设置的非常小心,由于在计算KL离散度时计算的是与标准高斯分布的距离,如果网络初始生成的模型均值方差都很大,那么与标准高斯分布距离就会非常大,这样会导致模型训练不出来,成成NAN的情况。

weights = {

    'w1':tf.Variable(tf.truncated_normal([n_input , n_hidden_1 ], stddev = 0.001)),

    'b1':tf.Variable(tf.zeros([n_hidden_1])),

    'mean_w1':tf.Variable(tf.truncated_normal([n_hidden_1, n_hidden_2], stddev = 0.001)),

    'log_sigma_w1':tf.Variable(tf.truncated_normal([n_hidden_1, n_hidden_2], stddev = 0.001)),

    'w2':tf.Variable(tf.truncated_normal([n_hidden_2, n_hidden_1], stddev = 0.001)),

    'b2':tf.Variable(tf.zeros([n_hidden_1])),

    'w3':tf.Variable(tf.truncated_normal([n_hidden_1, n_input], stddev = 0.001)),

    'b3':tf.Variable(tf.zeros([n_input])),

    'mean_b1':tf.Variable(tf.zeros([n_hidden_2])),

    'log_sigma_b1':tf.Variable(tf.zeros([n_hidden_2]))

}

3.定义网络结构

根据上图的的描述,网络节点可以按照以下代码来定义,在变分解码器中为训练的中间节点赋予了特殊的意义,让他们代表均值和方差,并将它们所代表的数据集向着标准高斯分布数据集靠近(也就是原始数据是样本,高斯分布数据是标签),然后可以使用KL离散度公式,来计算它所代表的集合与标准高斯分布集合(均值是0,方差为1的正态分布)间的距离,将这个距离当成误差,让它最小化从而优化网络参数。

这里的方差节点不是真正意义的方差,是取了log之后的,所以会有tf.exp(z_logma_sq)的变换,是取得方差的值,再通过tf.sqrt将其开平方得到标准差。用符合标准正态分布的一个数乘以标准差加上均值,就使这个数成为符合(z_mean,sigma)数据分布集合里面的一个点(z_mean是指网络生成的值,sigma是指网络生成的z_log-sigma_sq变换后的值)。

h1 = tf.nn.relu(tf.add(tf.matmul(x, weights['w1']), weights['b1']))
z_mean = tf.add(tf.matmul(h1, weights['mean_w1']), weights['mean_b1'])
z_log_sigma_sq = tf.add(tf.matmul(h1, weights['log_sigma_w1']), weights['log_sigma_b1'])

#高斯分布样本
eps = tf.random_normal(tf.stack([tf.shape(h1)[0], n_hidden_2]), 0, 1, dtype = tf.float32)
z = tf.add(z_mean, tf.multiply(tf.sqrt(tf.exp(z_log_sigma_sq)), eps))
h2 = tf.nn.relu(tf.matmul(z, weights['w2'])+weights['b2'])
reconstruction = tf.matmul(h2, weights['w3'])+weights['b3']

h2out = tf.nn.relu(tf.matmul(zinput, weights['w2'])+weights['b2'])
reconstructionout  = tf.matmul(h2out, weights['w3'])+weights['b3']

4.构建模型和反向传播

需要定义损失函数的节点和优化算法OP。

#计算loss
reconstr_loss = 0.5*tf.reduce_sum(tf.pow(tf.subtract(reconstruction, x), 2.0))
latent_loss = -0.5*tf.reduce_sum(1+z_log_sigma_sq
                                 -tf.square(z_mean)
                                 -tf.exp(z_log_sigma_sq), 1)
cost = tf.reduce_mean(reconstr_loss+latent_loss)
optimizer = tf.train.AdamOptimizer(learning_rate = 0.001).minimize(cost)

上述代码优化了网络的两个优化方向:

①一个是比较生成数据分布于标准高斯分布的距离,这里使用KL离散度的公式 。

②另一个是计算生成数据与原始数据之间的损失,这里用的是平方差,也可以用交叉熵。

最后将两种损失值放在一起,通过Adam的随机梯度下降算法实现训练中的优化参数。

5.参数设置,进行训练

设置训练参数,迭代50次,在session中每次循环取指定批次数据进行训练。

training_epochs = 50
batch_size = 128
display_step = 3

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples/batch_size)

        #遍历全部数据集
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            print(batch_xs)

            #输入数据
            _,c = sess.run([optimizer, cost], feed_dict = {x: batch_xs})
        #显示训练中的详细信息
        if epoch % display_step == 0:
            print("Epoch:", '%04d'% (epoch +1), "cost = ", "{:.9f}".format(c))
    print("完成")

    #测试
    print("Result:",cost.eval({x:mnist.test.images}))

输出图片如下图所示:

6.可视化部分

    #可视化结果
    show_num = 10
    pred = sess.run(reconstruction , feed_dict = {x: mnist.test.images[:show_num]})

    f, a = plt.subplots(2, 10, figsize = (10, 2))
    for i in range(show_num):
        a[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
        a[1][i].imshow(np.reshape(pred[i], (28, 28)))
    plt.draw()

    pred = sess.run(
        z, feed_dict={x: mnist.test.images})
    # x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
    plt.figure(figsize=(6, 6))
    plt.scatter(pred[:, 0], pred[:, 1], c=mnist.test.labels)
    plt.colorbar()
    plt.show()

 执行上述代码,如下图所示:

可以看到,第一行代表原始的样本图片,第二行代表变分自编码重建后生成的图片。可以见到生成的数字中不再一味单纯地学习形状,而是通过数据分布的方式学习规则,对原来的图片有更清晰的修正功能。

根据可视化代码,将均值和方差代表的二维数据在直角坐标系中展现如下:

                                 

可以看出,MNIST数据集同一类样本的特征分布还是比较集中的,说明变分自解码也具有降维功能,也可以用它进行分类任务的数据降维处理。

7.高斯分布取样,生成模拟数据

为了进一步证实模型学到的数据分布情况,这次在高斯分布抽样中取出一些点,将其映射到模型中的z,然后通过解码部分还原成真实的图片效果,代码如下:

    # display a 2D manifold of the digits
    n = 15  # figure with 15x15 digits
    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))
    grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
    grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

    for i, yi in enumerate(grid_x):
        for j, xi in enumerate(grid_y):
            z_sample = np.array([[xi, yi]])
            x_decoded = sess.run(reconstructionout, feed_dict={zinput: z_sample})

            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[i * digit_size: (i + 1) * digit_size,
            j * digit_size: (j + 1) * digit_size] = digit

    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='Greys_r')
    plt.show()

执行上面代码,生成如下:

 可以看到,从左下角到右上角显示了网络是按照图片的形状变化而排列的。

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值