tensorflow1.1/variational_autoencoder

环境tensorflow1.1,matplotlib2.02,python3

近年,非监督学习成为了研究热点。VAE(Variational Auto-Encoder,变分自编码器)和 GAN(Generative Adversarial Networks) 等模型,受到越来越多的关注

VAE:模型结构:
这里写图片描述

其中:loss = mse+KLDivergence

#coding:utf-8
"""
tensorflow 1.1
matplotlib 2.02
"""
import tensorflow as tf
import numpy as np
import input_data
import matplotlib.pyplot as plt

input_dim = 784
hidden_encoder_dim = 1200
hidden_decoder_dim = 1200
latent_dim = 200
epochs = 3000
batch_size = 100
N_pictures=3

mnist = input_data.read_data_sets('mnist/')

def weight_variable(shape):
    #tf.truncated_normal()截断的标准正态分布
    return tf.Variable(tf.truncated_normal(shape,stddev=0.001))

def bias_variable(shape):
    return tf.Variable(tf.truncated_normal(shape))

x = tf.placeholder('float32',[None,input_dim])
#在全连接层加入l2_regularization
l2_loss = tf.constant(0.0)

#encoder网络
w_encoder1 =weight_variable([input_dim,hidden_encoder_dim])
b_encoder1 = bias_variable([hidden_encoder_dim])
encoder1 = tf.nn.relu(tf.matmul(x,w_encoder1)+b_encoder1)
#第一层的l2_loss
l2_loss += tf.nn.l2_loss(w_encoder1)

#定义一个mu网络
mu_w_encoder2 = weight_variable([hidden_encoder_dim,latent_dim])
mu_b_encoder2 = bias_variable([latent_dim])
mu_encoder2 = tf.matmul(encoder1,mu_w_encoder2)+mu_b_encoder2
#mu网络的l2_loss
l2_loss += tf.nn.l2_loss(mu_w_encoder2)

#定义一个var网络
var_w_encoder2 = weight_variable([hidden_encoder_dim,latent_dim])
var_b_encoder2 = bias_variable([latent_dim])
var_encoder2 = tf.matmul(encoder1,var_w_encoder2)+var_b_encoder2
#var网络的l2_loss
l2_loss += tf.nn.l2_loss(var_w_encoder2)

#抽样
#生成标准正态分布
epsilon = tf.random_normal(tf.shape(var_encoder2))
new_var_encoder2 = tf.sqrt(tf.exp(var_encoder2))
#z的维度是latent_dim
z = mu_encoder2+tf.multiply(new_var_encoder2,epsilon)

#定义decoder网络
w_decoder1 = weight_variable([latent_dim,hidden_decoder_dim])
b_decoder1 = bias_variable([hidden_decoder_dim])
decoder1 = tf.nn.relu(tf.matmul(z,w_decoder1)+b_decoder1)
l2_loss += tf.nn.l2_loss(w_decoder1)
w_decoder2 = weight_variable([hidden_decoder_dim,input_dim])
b_decoder2 = bias_variable([input_dim])
#输出层没有使用激活函数(加入激活函数后面用log_px_given_z,不加入激活函数用cost1)
decoder2 = tf.nn.sigmoid(tf.matmul(decoder1,w_decoder2)+b_decoder2)
l2_loss += tf.nn.l2_loss(w_decoder2)

#计算cost
log_px_given_z = -tf.reduce_sum(x*tf.log(decoder2+1e-10)+(1-x)*tf.log(1-decoder2+1e-10),1)
#cost1 = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=decoder2,labels=x),reduction_indices=1)
#计算KL Divergence
KLD = -0.5*tf.reduce_sum(1+var_encoder2-tf.pow(mu_encoder2,2)-tf.exp(var_encoder2),reduction_indices=1)
cost = tf.reduce_mean(log_px_given_z+KLD)
#加上regularization 
regularized_cost = cost + l2_loss

train = tf.train.AdamOptimizer(0.01).minimize(cost)

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    #画图,2行5列返回图和子图
    figure_,a = plt.subplots(2,N_pictures,figsize=(6,4))
    #开始交互模式
    plt.ion()
    #测试的图
    view_figures = mnist.test.images[:N_pictures]
    for i in range(N_pictures):
        #将图片reshape为28行28列显示
        a[0][i].imshow(np.reshape(view_figures[i],(28,28)))
        #清空x轴,y轴坐标
        a[0][i].set_xticks(())
        a[0][i].set_yticks(())
    for step in range(10000):
        batch_x,batch_y = mnist.train.next_batch(batch_size)
        #encoder3和decoder3需要进行run
        _,encoded,decoded,c = sess.run([train,z,decoder2,cost],feed_dict={x:batch_x})
        if step % 1000 ==0:
            print('= = = = = = > > > > > >','train loss:% .4f' % c)
            #将真实的图片和autoencoder后的图片对比
            decoder_figures = sess.run(decoder2,feed_dict={x:view_figures})
            for i in range(N_pictures):
                #清除第一行图片
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoder_figures[i],(28,28)))
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            plt.draw()
            plt.pause(1)
    plt.ioff() #关闭交互模式


"""
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    for epoch in range(epochs):
        batch_x,batch_y = mnist.train.next_batch(batch_size)
        _,c = sess.run([train,cost],feed_dict={x:batch_x})

        if epoch % 100 == 0:
            print('- - - - - - > > > > > > epoch: ',int(epoch/100),'cost: %.4f' %c)

    #输出结果可视化

    encoder_result = sess.run(z,feed_dict={x:mnist.test.images})
    plt.scatter(encoder_result[:,0],encoder_result[:,1],c = mnist.test.labels,label='mnist distributions')
    plt.legend(loc='best')
    plt.title('different mnist digits shows in figure')
    plt.colorbar()
    plt.show()

"""

结果

这里写图片描述

这里写图片描述

聚类效果:
这里写图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值