Tensorflow 实现softmax识别手写数字
导入数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot= True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
查看这个数据集的样式
* 我们可以看出,训练集的X是有55000张图片,每张图片有784个维度,训练集的Y有55000个值,每个值有十个维度(分别代表着十个数字)
print(mnist.train.images.shape,mnist.train.labels.shape) # 查看训练集的样式,
(55000, 784) (55000, 10)
print(mnist.test.images.shape,mnist.test.labels.shape) # 查看测试集的样式
(10000, 784) (10000, 10)
定义X,W,B,Y
import tensorflow as tf
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32,[None,784])
w = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,w) + b)
定义损失函数
* 这里我们选用交叉熵作为损失函数
y_ = tf.placeholder(tf.float32,[None,10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))
#其中tf.reduce_mean()是对每个batch求平均值,tf.reduce_sum()是为了求每一个batch的交叉熵的和
定义优化算法
* 在这里我们使用随机梯度下降作为优化算法
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 定义优化不步长是0.5优化的对象是cross_entity
全部定义完毕,我们可以使用全局参数初始化器,并执行他的run方法
tf.global_variables_initializer().run()
每次选出100个数据喂入我们的训练器
for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(100)
train_step.run({x:batch_xs,y_:batch_ys})#这里利用train_step.run()来运行优化算法
查看准确率
* 用tf.equal()判断两个值是否相等
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
查看一个张量的值,用tensor.eval(),eval里面可以加字典参数
其中若tensor是一个未知数,则需要喂入一个字典,也就是字典中的参数能够求得tensor中的未知数,在下面,我们不知道y,但是我们知道y可以由x求得,所以我们可以在字典参数中写入x:batch_xs,所以可求
batch_xs,batch_ys = mnist.train.next_batch(100)
print(tf.argmax(y,1).eval({x:batch_xs}))
[2 3 2 9 9 0 1 2 7 9 5 1 1 6 8 7 4 2 3 9 0 4 7 5 7 0 1 1 2 5 7 1 6 0 1 1 1
8 2 1 9 7 1 6 7 0 9 7 1 9 9 2 5 2 7 4 9 2 3 7 2 4 2 3 2 9 3 7 3 3 6 7 8 6
0 5 0 1 9 7 4 1 8 7 8 6 4 6 6 8 6 0 1 9 5 9 8 4 2 2]
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))# 将数据转化为float32
print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))
0.9188