自编码网络实现Mnist

#!/usr/bin/python3
# -*-coding:utf-8 -*-
# @Time   :2018/3/16 
# @Author :machuanbin


import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import  numpy as np
import matplotlib.pyplot as plt
#超参
lr=0.001
training_epoch=20 #训练多少轮
batch_size=128  #每次训练数据多少
display_step=1 #每隔多少轮显示一次训练结果


#神经网络的参数
n_input=784

#从测试集中选择10张照片去验证自动编码器结果
examples_to_show=10

current_dir = os.path.abspath('.\MNIST_data')

mnist=input_data.read_data_sets(current_dir,one_hot=True)

#无监督学习,只需要输入图片
X=tf.placeholder(tf.float32,[None,n_input])

#两个隐含层
#第一个隐含层256#第二层128
n_hidden_1=256
n_hidden_2=128

#设置每一层的权重和偏差

weights={
    'encoder_h1':tf.Variable(tf.random_normal([n_input,n_hidden_1])),
    'encoder_h2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2])),
    'decoder_h1':tf.Variable(tf.random_normal([n_hidden_2,n_hidden_1])),
    'decoder_h2':tf.Variable(tf.random_normal([n_hidden_1,n_input]))
}
biases={
    'encoder_b1':tf.Variable(tf.random_normal([n_hidden_1])),
    'encoder_b2':tf.Variable(tf.random_normal([n_hidden_2])),
    'decoder_b1':tf.Variable(tf.random_normal([n_hidden_1])),
    'decoder_b2':tf.Variable(tf.random_normal([n_input]))
}


#定义压缩函数
def encoder(x):
    layer_1=tf.nn.sigmoid(tf.add(tf.matmul(x,weights['encoder_h1']),biases['encoder_b1']))
    layer_2=tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['encoder_h2']),biases['encoder_b2']))
    return layer_2

def decoder(x):
    layer_1=tf.nn.sigmoid(tf.add(tf.matmul(x,weights['decoder_h1']),biases['decoder_b1']))
    layer_2=tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['decoder_h2']),biases['decoder_b2']))
    return layer_2

#构建模型
encoder_op=encoder(X)
decoder_op=decoder(encoder_op)

#得出预测值
y_pred=decoder_op
#得出真实值
y_true=X

#定义损失函数和优化器
cost=tf.reduce_mean(y_true-y_pred,2)
optimizer=tf.train.RMSPropOptimizer(lr).minimize(cost)

#训练数据及模型评估

with tf.Session as sess:
    sess.run(tf.global_variables_initializer())

    total_batch=int(mnist.train.num_examples/batch_size)

    #开始训练

    for epoch in range(training_epoch):
        for i in range(total_batch):
            batch_x,batch_y=mnist.train.next_batch(batch_size)

            _,c=sess.run([optimizer,cost],feed_dict={X:batch_x})
        if epoch%display_step==0:
            print('Epoch:','%0.4d'%(epoch+1),'cost=','0.9f'.format(c))

    print("Optimization Finished")


    #对测试集应用训练好的自动编码网络
    encoder_decoder=sess.run(y_pred,feed_dict={X:mnist.test.images[:examples_to_show]})

    #比较测试集原始数据和自动编码网络的重建结果
    f,a=plt.subplot(2,10,figsize=(10,2))
    for i in range(examples_to_show):
        a[0][i].imshow(np.reshape(mnist.test.images[i],(28,28)))
        a[1][i].imshow(np.reshape(encoder_decoder[i],(28,28)))

    f.show()
    plt.draw()
    plt.waitforbuttonpress()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值