学习笔记出处
一.概念说明
编写TensorFlow需要两个步骤:
1)组装一个graph
2) 使用session去执行graph中的operation
Tensor: 类型化的多维数组
Operation:执行计算的单元,图的节点
Graph:一张有边与点的图,其表示了需要进行计算的任务
Session:称之为会话的上下文,用于执行图
二.数据结构
TensorFlow的数据结构有着rank,shape,data,types的概念
1) rank
一般是指数据的维度,其与线性代数中的rank不是一个概念。
2) shape
shape指tensor每个维度数据的个数,可以用python的list/tuple表示
3)Data type
Data type 是指单个数据的类型。常用DT_FLOAT,也就是32位浮点数
三.简单手写数字识别模型
或者通过代码获得数据:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/',one_hot=True)
该数据集中一共有550000个样本,其中50000用于训练,5000用于验证,每个样本分为X与Y两部分,其中X如下图所示,8*28的像素,在使用时需要拉伸成784维的向量
y为x真实的类别,其数据可以看做如下图的形式。因此,问题可以看成一个10分类的问题:
本次演示所使用的模型是逻辑回归,其表示为:
当使用TensorFlow进行graph构建时,大体可以分为五部分:
1)为输入x与输出y定义placeholder;
2) 定义权重w;
3) 定义模型结构;
4)定义损失函数;
5)定义优先算法;
具体步骤 :
1)首先导入需要的包,定义x与y的placeholder以及W,b的variables。其中None表示任意维度,一般是min-batch的batch size 而w定义为shape的784,10,rank为2的variable,b是shape为10,rank为1的variable。
import tensorflow as tf
x = tf.placeholder(tf.float32, [None,784])
y_ = tf.placeholder(tf.float32, [None, 10])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
2) 定义模型。x与W矩阵乘法后与b求和,经过softmax得到y
y = tf.nn.softmax(tf.matmul(x, W) + b)
3) 求逻辑回归的损失函数,这里使用了cross entropy,其公式可以表示为
这里的cross entropy取了均值。定义了学习步长为0.5,使用了梯度下降算法最小化优化函数。记得初始化Variables
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entroty)
init = tf.global_variables_initializer()
4) graph定义完后,就开始真正的计算,包括初始化变量,输入数据,并计算损失函数与利用优化算法更新参数
with tf.Session() as sess:
sess.run(init)
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
5) 迭代了1000次,每次输入了100个样本,mnist.train.next_batch 就是生成下一个batch的数据,这里知道它在干啥就行。训练结果如何需要进行评估。这里使用单纯的正确率,正确率使用最大值索引是否相等的方式,因为正确的label最大值为1,而预测的label最大值为最大概率
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuray, feed_dict={x: mnist.test.images, y_: mnist.test,labels}))