#!/usr/bin/python3 # -*-coding:utf-8 -*- # @Time :2018/3/16 # @Author :machuanbin import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import numpy as np import matplotlib.pyplot as plt #超参 lr=0.001 training_epoch=20 #训练多少轮 batch_size=128 #每次训练数据多少 display_step=1 #每隔多少轮显示一次训练结果 #神经网络的参数 n_input=784 #从测试集中选择10张照片去验证自动编码器结果 examples_to_show=10 current_dir = os.path.abspath('.\MNIST_data') mnist=input_data.read_data_sets(current_dir,one_hot=True) #无监督学习,只需要输入图片 X=tf.placeholder(tf.float32,[None,n_input]) #两个隐含层 #第一个隐含层256个 #第二层128个 n_hidden_1=256 n_hidden_2=128 #设置每一层的权重和偏差 weights={ 'encoder_h1':tf.Variable(tf.random_normal([n_input,n_hidden_1])), 'encoder_h2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2])), 'decoder_h1':tf.Variable(tf.random_normal([n_hidden_2,n_hidden_1])), 'decoder_h2':tf.Variable(tf.random_normal([n_hidden_1,n_input])) } biases={ 'encoder_b1':tf.Variable(tf.random_normal([n_hidden_1])), 'encoder_b2':tf.Variable(tf.random_normal([n_hidden_2])), 'decoder_b1':tf.Variable(tf.random_normal([n_hidden_1])), 'decoder_b2':tf.Variable(tf.random_normal([n_input])) } #定义压缩函数 def encoder(x): layer_1=tf.nn.sigmoid(tf.add(tf.matmul(x,weights['encoder_h1']),biases['encoder_b1'])) layer_2=tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['encoder_h2']),biases['encoder_b2'])) return layer_2 def decoder(x): layer_1=tf.nn.sigmoid(tf.add(tf.matmul(x,weights['decoder_h1']),biases['decoder_b1'])) layer_2=tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['decoder_h2']),biases['decoder_b2'])) return layer_2 #构建模型 encoder_op=encoder(X) decoder_op=decoder(encoder_op) #得出预测值 y_pred=decoder_op #得出真实值 y_true=X #定义损失函数和优化器 cost=tf.reduce_mean(y_true-y_pred,2) optimizer=tf.train.RMSPropOptimizer(lr).minimize(cost) #训练数据及模型评估 with tf.Session as sess: sess.run(tf.global_variables_initializer()) total_batch=int(mnist.train.num_examples/batch_size) #开始训练 for epoch in range(training_epoch): for i in range(total_batch): batch_x,batch_y=mnist.train.next_batch(batch_size) _,c=sess.run([optimizer,cost],feed_dict={X:batch_x}) if epoch%display_step==0: print('Epoch:','%0.4d'%(epoch+1),'cost=','0.9f'.format(c)) print("Optimization Finished") #对测试集应用训练好的自动编码网络 encoder_decoder=sess.run(y_pred,feed_dict={X:mnist.test.images[:examples_to_show]}) #比较测试集原始数据和自动编码网络的重建结果 f,a=plt.subplot(2,10,figsize=(10,2)) for i in range(examples_to_show): a[0][i].imshow(np.reshape(mnist.test.images[i],(28,28))) a[1][i].imshow(np.reshape(encoder_decoder[i],(28,28))) f.show() plt.draw() plt.waitforbuttonpress()
自编码网络实现Mnist
最新推荐文章于 2024-04-08 11:24:35 发布