Tensorflow_MNIST数据集上结果改进

原创 2018年04月17日 14:24:18
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

#导入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
batch_size = 100
n_batch = mnist.train.num_examples // batch_size

#定义placeholder
x= tf.placeholder(tf.float32,[None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)
lr = tf.Variable(0.001, dtype=tf.float32)

#构造神经网络
W1 = tf.Variable(tf.truncated_normal([784,500], stddev = 0.1))
b1 = tf.Variable(tf.zeros([500])+0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1)+b1)
L1_drop = tf.nn.dropout(L1, keep_prob)

W2 = tf.Variable(tf.truncated_normal([500, 300], stddev=0.1))
b2 = tf.Variable(tf.zeros([300])+0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2)+b2)
L2_drop = tf.nn.dropout(L2, keep_prob)

W3 = tf.Variable(tf.truncated_normal([300, 10],stddev=0.1))
b3 = tf.Variable(tf.zeros([10])+0.1)
prediction = tf.nn.softmax(tf.matmul(L2_drop, W3)+b3)

#定义损失函数和训练的优化器
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
train_step = tf.train.AdamOptimizer(lr).minimize(loss)

#变量初始化
init = tf.global_variables_initializer()

#计算准确率
correction_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction,1))
accuracy = tf.reduce_mean(tf.cast(correction_prediction, tf.float32))

#创建会话,分批迭代训练,每次迭代调整lr,并计算识别准确率
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(51):
        sess.run(tf.assign(lr, 0.001 * (0.95 ** epoch)))
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys, keep_prob:1.0})
            
        learning_rate = sess.run(lr)
        acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels, keep_prob:1.0})
        print("Iter " + str(epoch) + ", Testing Accuracy " + str(acc) + ", Learing Rate " + str(learning_rate))

Iter 50, Testing Accuracy 0.9815, Learing Rate 7.6944976e-05

结果有待进一步改进。。。。

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/lansetiankong2104/article/details/79974284

TensorFlow读取MNIST数据集错误的问题

TensorFlow读取mnist数据集错误的问题 运行程序出现”URLError”错误的问题 可能是服务器或路径的原因,可以自行下载数据集后,将数据集放到代码所在的文件夹下,并将路径改为: ...
  • jiaoyangwm
  • jiaoyangwm
  • 2018-02-04 20:53:56
  • 70

tensorflow_mnist数据集卷积神经网络实例

程序1: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import ma...
  • fireflychh
  • fireflychh
  • 2017-06-27 14:37:17
  • 104

RLS improved whitening filter

  • 2010年09月18日 17:39
  • 423KB
  • 下载

周志华机器学习,3.3编程实现对率回归,并给出西瓜数据集3.0α上的结果

3.3编程实现对率回归,并给出西瓜数据集3.0α上的结果数据集: 1 0.697 0.460 1 2 0.774 0.376 1 3 0.634 0.264 1 4 0.608 0.318 1 ...
  • zjy_lilas
  • zjy_lilas
  • 2017-03-21 11:35:54
  • 2778

Tensorflow_MNIST

  • 2017年09月12日 16:37
  • 2KB
  • 下载

周志华《机器学习》课后习题解答系列(四):Ch3.5 - 编程实现线性判别分析

3.5 编程实现线性判别分析(LDA),这里采用基于sklearn和自己编程实现两种方式实现线性判别。...
  • Snoopy_Yuan
  • Snoopy_Yuan
  • 2017-03-21 16:32:21
  • 1488

读懂《机器学习实战》代码—K-近邻算法改进约会网站配对效果

从上一篇文章大概了解了K-近邻算法的原理,并实现了分类函数: #inX为用于分类的输入向量 #dataSet为输入的训练样本集 #lables为标签向量 #参数k表示用于选择最近邻居的数...
  • u013457382
  • u013457382
  • 2016-03-20 18:38:03
  • 2006

k-近邻算法改进约会网站的配对效果

在上一篇的基础上增加如下代码:''' 将文本记录转换到NumPy的解析程序 输入为文件名字符串 输出为训练样本矩阵和类标签向量 ''' def file2matrix(filename): f...
  • u012319493
  • u012319493
  • 2016-05-14 18:18:14
  • 1135

使用不同的SVM对iris数据集进行分类并绘出结果

使用不同的SVM对iris数据集进行分类并绘出结果标签: 机器学习 Python译文之前的碎碎念SVM学习了也有一段时间了,公式基本都推导了一遍,明显感觉SVM的推导过程比之前学习的机器学习模型的推导...
  • xuelabizp
  • xuelabizp
  • 2016-04-11 16:03:54
  • 6773
收藏助手
不良信息举报
您举报文章:Tensorflow_MNIST数据集上结果改进
举报原因:
原因补充:

(最多只允许输入30个字)