Implementing a Variational Autoencoder (VAE)

1. A VAE is somewhat similar to a GAN: both can generate new sample data from some input. The difference is that a VAE assumes the latent representation of the data follows a normal distribution, while a GAN makes no such assumption and is entirely data-driven, learning the underlying structure through training.
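Concretely, that normality assumption shows up in the training objective. A VAE minimizes a reconstruction loss plus the KL divergence between the encoder's latent distribution $\mathcal{N}(\mu, \sigma^2)$ and a standard normal prior. Writing $\gamma = \log \sigma^2$ (the quantity the code below calls hidden_middle_gamma), the KL term has the closed form used directly in latent_loss below:

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\big) = \frac{1}{2}\sum_i \left(e^{\gamma_i} + \mu_i^2 - 1 - \gamma_i\right), \qquad \gamma_i = \log \sigma_i^2$$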

Below is the code for the variational autoencoder:

import numpy as np
import tensorflow as tf
import tensorflow.contrib as contrib
from tensorflow.contrib.layers import fully_connected
from tensorflow.examples.tutorials.mnist import input_data
import functiontool  # helper module containing the plotting function shown below

# Define some global hyperparameters
n_inputs = 28 * 28
n_hidden1 = 500
n_hidden2 = 500
n_hidden_middle = 30  # dimension of the latent coding layer
n_hidden3 = n_hidden2
n_hidden4 = n_hidden1
n_outputs = n_inputs
learning_rate = 0.001
mnist_data = input_data.read_data_sets("MNIST_data/")

# Define the network structure
with contrib.framework.arg_scope([fully_connected], activation_fn=tf.nn.elu,
                                 weights_initializer=contrib.layers.variance_scaling_initializer()):
    X = tf.placeholder(dtype=tf.float32, shape=[None, n_inputs])
    hidden1 = fully_connected(X, n_hidden1)
    hidden2 = fully_connected(hidden1, n_hidden2)
    # The encoder outputs the mean and gamma = log(sigma^2) of the latent Gaussian
    hidden_middle_mean = fully_connected(hidden2, n_hidden_middle, activation_fn=None)
    hidden_middle_gamma = fully_connected(hidden2, n_hidden_middle, activation_fn=None)
    hidden_middle_sigma = tf.exp(0.5 * hidden_middle_gamma)
    # Reparameterization trick: z = mean + sigma * epsilon, with epsilon ~ N(0, 1),
    # so gradients can flow through the sampling step
    noise = tf.random_normal(tf.shape(hidden_middle_sigma))
    hidden_middle = hidden_middle_mean + hidden_middle_sigma * noise
    hidden3 = fully_connected(hidden_middle, n_hidden3)
    hidden4 = fully_connected(hidden3, n_hidden4)
    logits = fully_connected(hidden4, n_outputs, activation_fn=None)
    outputs = tf.sigmoid(logits)
# Define the loss: reconstruction cross-entropy plus the latent KL divergence
reconstruction_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=X, logits=logits))
latent_loss = 0.5 * tf.reduce_sum(tf.exp(hidden_middle_gamma) + tf.square(hidden_middle_mean) - 1 - hidden_middle_gamma)
sum_loss = reconstruction_loss + latent_loss
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_optimizer = optimizer.minimize(sum_loss)
init = tf.global_variables_initializer()
saver = tf.train.Saver()
# Train the network
n_epochs = 60
batch_size = 150
with tf.Session() as session:
    init.run()
    for epoch in range(n_epochs):
        n_batches = mnist_data.train.num_examples // batch_size
        for iteration in range(n_batches):
            print("\r{}%".format(100 * iteration // n_batches), end="")
            X_batch, y_batch = mnist_data.train.next_batch(batch_size)
            session.run(train_optimizer, feed_dict={X: X_batch})
        loss_val = sum_loss.eval(feed_dict={X: X_batch})
        print("\rTrain loss: {}".format(loss_val))
        saver.save(session, "weight/VaAuto.ckpt")
    # Generate new digits by feeding random latent codes into the decoder
    codes_rnd = np.random.normal(size=(10, n_hidden_middle))
    out_val = outputs.eval(feed_dict={hidden_middle: codes_rnd})
    functiontool.show_reconstructed_digits_old(out_val)
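As a quick sanity check on the latent_loss formula above, here is a small NumPy sketch (illustrative only, not part of the original code; kl_closed_form and kl_monte_carlo are hypothetical helper names) comparing the closed-form KL divergence against a Monte Carlo estimate drawn with the same reparameterization:

import numpy as np

def kl_closed_form(mu, gamma):
    # KL(N(mu, exp(gamma)) || N(0, 1)) summed over latent dimensions;
    # gamma is log(sigma^2), matching hidden_middle_gamma in the graph above
    return 0.5 * np.sum(np.exp(gamma) + mu ** 2 - 1.0 - gamma)

def kl_monte_carlo(mu, gamma, n_samples=100000):
    sigma = np.exp(0.5 * gamma)
    # Sample z with the same reparameterization used in the network
    z = mu + sigma * np.random.randn(n_samples, mu.size)
    # Average log q(z) - log p(z) over samples drawn from q
    log_q = -0.5 * (((z - mu) / sigma) ** 2 + gamma + np.log(2 * np.pi))
    log_p = -0.5 * (z ** 2 + np.log(2 * np.pi))
    return np.mean(np.sum(log_q - log_p, axis=1))

mu = np.array([0.3, -1.2, 0.8])
gamma = np.array([-0.5, 0.1, -1.0])
print(kl_closed_form(mu, gamma))    # analytic KL
print(kl_monte_carlo(mu, gamma))    # should be close to the analytic value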

The plotting function used above is:


import matplotlib.pyplot as plt

def plot_image(image, shape=(28, 28)):
    # Minimal helper (assumed here, not shown in the original listing):
    # draw one flattened grayscale digit
    plt.imshow(image.reshape(shape), cmap="Greys", interpolation="nearest")
    plt.axis("off")

def show_reconstructed_digits_old(outputs):
    plt.figure(figsize=(8, 50))
    for i in range(outputs.shape[0]):
        plt.subplot(outputs.shape[0], 1, i + 1)
        plot_image(outputs[i])
    plt.show()
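Since training saves a checkpoint each epoch, generation can also be rerun later without retraining; a minimal sketch, assuming the graph defined above is still in scope (codes and digits are illustrative names):

with tf.Session() as session:
    saver.restore(session, "weight/VaAuto.ckpt")
    # Decode random latent codes into digit images
    codes = np.random.normal(size=(10, n_hidden_middle))
    digits = outputs.eval(feed_dict={hidden_middle: codes})
    functiontool.show_reconstructed_digits_old(digits)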

The training results are:

[Figure: digit images generated by the trained VAE]