#首先载入tensorflow库
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#数据放在这里
mnist=input_data.read_data_sets('./MNIST',one_hot=True)
sess=tf.InteractiveSession()
#第一个参数——数据类型,第二个参数784表示图像的像素个数,None表示图像的数量不确定
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])
#为softmax模型中的weights和biases创建Variable对象(存储模型参数)
#tensor存储数据----一旦用掉就会消失,Variable在模型训练的迭代中是持久化的
W=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([1,10]))
#Softmax是tf.nn下的一个函数
y=tf.nn.softmax(tf.matmul(x,W)+b)
#定义Cross-entropy来计算Loss。Loss越小,表示模型的分类结果和真实值之间的偏差越小
#图片的真实分类,每张图片都有一个已经标记好的分类
y_=tf.placeholder(tf.float32,[None,10])
#cross_entropy求的就是loss
cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))
#接下来需要用随机梯度下降法,不断地优化Loss
train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
#使用全局参数初始化器来进行全局的初始化
tf.global_variables_initializer().run()
for i in range(1000):
batch_xs,batch_ys=mnist.train.next_batch(100)
train_step.run({x:batch_xs,y_:batch_ys})
#接下来需要判断模型是否准确
correct_predition=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_predition,tf.float32))
print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))
tensorflow实战第三章——用Softmax Regression识别手写数字
最新推荐文章于 2022-02-10 15:44:46 发布