TensorFlow入门的第一个程序
TensorFlow是谷歌深度学习的框架,暑假培训对框架有了初步的认识,也对人工智能有了新的认识,通过这个手写数字的识别,对机器学习有了新的认识,机器学习:即对大量数据(样本)进行分析,对样本进行降维,通过不同的算法提取特征值,通过大量数据对模型进行拟合训练,类似在 y = w *x +b 中,通过利用feed_dict 对占位符x与y_ 进行赋值 ,模型通过计算得到预测值,然后用预测值与我们所给实际值 y 进行对比,利用梯度下降法不断对开销进行优化,不断寻找局部最优解,达到对模型的训练,在机器学习中,样本通常进行80% 用来训练, 20%用来测试
其中TensorFlow框架中不能直接输出,例如 x = a+ b ,在框架中需要调用 “会话”去计算或者输出
首先: sess= tf.Session() 创建对话
sess.run(x)
以下是手写数字识别的代码:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
#创建会话
sess = tf.Session()
#创建两个占位符,在这里x是图片,在这里784= 28* 28 ,None 为任意数量图片,由于测试集合较大,所以在后面规定100张为一个最小集合进行训练,在这里利用784 是将图片的二维结构降为一维,因为在灰度图像中0 为白色 ,1 为黑色 ,所以将为一个一维进行处理,然而y_ 则为正确结果,在这里我们是预测手写数字的数字是多少,在这里我们是通过识别0-9 这10个数子,当数字为 1 时,y_ 为[0,1,0,0,0,0,0,0,0,0]
x=tf.placeholder(tf.float32, [None ,784])
y_ = tf.placeholder(tf.float32 , [None,10])
#下载数据
mnist = input_data.read_data_sets("MNIST_data/",one_hot =True)
#在这里 y_ = w * x + b, 其中因为 y,x 不是一维的数字,所以我们在产生 w 与 b 的时候需要与 x y_ 一致, 这里x是一个二维的,所以根据矩阵的乘法, 我们定义w 也是二维,并且为[784,10] , 而b则与 y_ 一致即可 [10].
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
#softmax() 函数是一个比较有意思的函数吧,softmax Regression 会话是对每一类别估算一个概率,比如识别某张图片中1的概率为10%,而8的概率为 80%, 则最后选取概率最大的那个数字作为模型的输出结果
#在这里需要注意的是matmul(x,w) 与multiply ( ),简单的来说matmul ( )就是矩阵的运算
y = tf.nn.softmax(tf.matmul(x,W)+b)
#定义损失函数,在这里reduce_mean()是求均值,reduce_sum( )是求和
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))
#模型的训练,常用的有随机梯度下降SDG,定义好优化算法后,TensorFlow 根据我们定义的这个计算图自动求导,并且根据反向传播算法进行训练,在每次迭代中更新参数来不断减小loss
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
#定义初始化函数,在后面通过会话调用。
init = tf.global_variables_initializer()
sess.run(init)
#进行 1000 次训练 ,通过feed_dict 为占位符进行赋值
for i in range (1000):
batch_xs ,batch_ys = mnist.train.next_batch(100)
sess.run(train_step,feed_dict={x:batch_xs,y_:batch_ys})
#进行验证, argmax()找出最大值的序号,其中argmax(y,1)是寻找预测值,argmax(y_,1)是寻找真实值,最后返回计算分类是否正确的的类别
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
print(sess.run(accuracy, feed_dict={x: batch_xs, y_: batch_ys}))
print(batch_xs.shape)
最终识别准确率可以在90%左右。