Tensorflow实现MNIST数据自编码(2)

对自编码(1)进行改进,(1)中实现的网络是用2个编码层,2个解码层,现在对它进行添加编码层和解码层分别为4层

原始数据784--》256--》64--》16--》2

  1. #encoding=utf-8  
  2. import tensorflow as tf   
  3. from tensorflow.examples.tutorials.mnist import input_data  
  4. mnist = input_data.read_data_sets('/data',one_hot=True)  
  5.   
  6. learning_rate = 0.01  
  7. n_hidden_1 = 256  
  8. n_hidden_2 = 64  
  9. n_hidden_3 = 16  
  10. n_hidden_4 = 2  
  11. n_input = 784 #输入图片大小28*28  
  12.   
  13. x = tf.placeholder('float',[None,n_input])  
  14. y = x  
  15.   
  16. weights = {  
  17.     'encoder_h1':tf.Variable(tf.random_normal([n_input,n_hidden_1])),  
  18.     'encoder_h2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2])),  
  19.     'encoder_h3':tf.Variable(tf.random_normal([n_hidden_2,n_hidden_3])),  
  20.     'encoder_h4':tf.Variable(tf.random_normal([n_hidden_3,n_hidden_4])),  
  21.     'decoder_h1':tf.Variable(tf.random_normal([n_hidden_4,n_hidden_3])),  
  22.     'decoder_h2':tf.Variable(tf.random_normal([n_hidden_3,n_hidden_2])),  
  23.     'decoder_h3':tf.Variable(tf.random_normal([n_hidden_2,n_hidden_1])),  
  24.     'decoder_h4':tf.Variable(tf.random_normal([n_hidden_1,n_input])),  
  25. }  
  26. biases = {  
  27.     'encoder_b1':tf.Variable(tf.zeros([n_hidden_1])),  
  28.     'encoder_b2':tf.Variable(tf.zeros([n_hidden_2])),  
  29.     'encoder_b3':tf.Variable(tf.zeros([n_hidden_3])),  
  30.     'encoder_b4':tf.Variable(tf.zeros([n_hidden_4])),  
  31.     'decoder_b1':tf.Variable(tf.zeros([n_hidden_3])),  
  32.     'decoder_b2':tf.Variable(tf.zeros([n_hidden_2])),  
  33.     'decoder_b3':tf.Variable(tf.zeros([n_hidden_1])),  
  34.     'decoder_b4':tf.Variable(tf.zeros([n_input])),  
  35. }  
  36.   
  37. #定义网络模型  
  38. def encoder(x):  
  39.     layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x,weights['encoder_h1']),biases['encoder_b1']))  
  40.     layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['encoder_h2']),biases['encoder_b2']))  
  41.     layer_3 = tf.nn.sigmoid(tf.add(tf.matmul(layer_2,weights['encoder_h3']),biases['encoder_b3']))  
  42.     layer_4 = tf.nn.sigmoid(tf.add(tf.matmul(layer_3,weights['encoder_h4']),biases['encoder_b4']))  
  43.     return layer_4  
  44. def decoder(x):  
  45.     layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['decoder_h1']), biases['decoder_b1']))  
  46.     layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['decoder_h2']), biases['decoder_b2']))  
  47.     layer_3 = tf.nn.sigmoid(tf.add(tf.matmul(layer_2, weights['decoder_h3']), biases['decoder_b3']))  
  48.     layer_4 = tf.nn.sigmoid(tf.add(tf.matmul(layer_3, weights['decoder_h4']), biases['decoder_b4']))  
  49.     return layer_4  
  50.   
  51. y_pred = decoder(encoder(x))  
  52. print('y_pred',y_pred)  
  53.   
  54. cost = tf.reduce_mean(tf.pow(y-y_pred,2))  
  55. optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)  
  56.   
  57. training_epochs = 20  
  58. batch_size = 256  
  59. display_step = 1  
  60.   
  61. with tf.Session() as sess:  
  62.     tf.global_variables_initializer().run()  
  63.     total_batch = int(mnist.train.num_examples/batch_size)  
  64.     #循环开始训练  
  65.     for epoch in range(training_epochs):  
  66.         #遍历全部数据集  
  67.         for i in range(total_batch):  
  68.             batch_xs,batch_ys = mnist.train.next_batch(batch_size)  
  69.             _,c = sess.run([optimizer,cost],feed_dict={x:batch_xs})  
  70.         if epoch%display_step == 0:  
  71.             print('Epoch:','%04d'%(epoch+1),'cost=','{:.9f}'.format(c))  
  72.     print('finished')  

1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md或论文文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。 5、资源来自互联网采集,如有侵权,私聊博主删除。 6、可私信博主看论文后选择购买源代码。 1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md或论文文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。 5、资源来自互联网采集,如有侵权,私聊博主删除。 6、可私信博主看论文后选择购买源代码。 1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md或论文文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。 5、资源来自互联网采集,如有侵权,私聊博主删除。 6、可私信博主看论文后选择购买源代码。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值