训练一个机器学习模型用于预测图片里面的数字

# 预测准确率# MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片# 它也包含每一张图片对应的标签,告诉我们这个是数字几。比如,四张图片的标签分别是5,0,4,1。## 此教程将训练一个机器学习模型用于预测图片里面的数字。# 目的不是要设计一个世界一流的复杂模型# 而是介绍下如何使用TensorFlow。所以,这里会从一个很简单的数学模型开始,它叫做Softmax Regression。import input_data# use datamnist = input_.
摘要由CSDN通过智能技术生成
# 预测准确率
# MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片
# 它也包含每一张图片对应的标签,告诉我们这个是数字几。比如,四张图片的标签分别是5,0,4,1。
#
# 此教程将训练一个机器学习模型用于预测图片里面的数字。
# 目的不是要设计一个世界一流的复杂模型
# 而是介绍下如何使用TensorFlow。所以,这里会从一个很简单的数学模型开始,它叫做Softmax Regression。


import input_data
# use data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

import tensorflow as tf

# before open session。没有这行代码会报错
tf.compat.v1.disable_eager_execution()

# set x,W,b. Make y = Wx + b
x = tf.compat.v1.placeholder("float", [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x,W) + b)  # y=wx+b

# 交叉熵产生于信息论里面的信息压缩编码技术,但是它后来演变成为从博弈论到机器学习等其他领域里的重要技术手段
# y 是我们预测的概率分布, y' 是实际的分布(我们输入的one-hot vector)。
# 比较粗糙的理解是,交叉熵是用来衡量我们的预测用于描述真相的低效性。
y_ = tf.compat.v1.placeholder("float", [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.compat.v1.log(y))

# 用梯度下降算法(gradient descent algorithm)以0.01的学习速率最小化交叉熵
train_step = tf.compat.v1.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# 初始化我们创建的变量
init = tf.compat.v1.initialize_all_variables()

# 在一个Session里面启动我们的模型,并且初始化变量
sess = tf.compat.v1.Session()
sess.run(init)

# 让模型循环训练1000次
# 该循环的每个步骤会随机抓取训练数据中的100个批处理数据点,然后用这些数据点作为参数替换之前的占位符来运行train_step。
for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# tf.argmax 是一个非常有用的函数,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。
# 由于标签向量是由0,1组成,因此最大值1所在的索引位置就是类别标签
# 比如tf.argmax(y,1)返回的是模型对于任一输入x预测到的标签值,而 tf.argmax(y_,1) 代表正确的标签
# 可以用 tf.equal 来检测我们的预测是否真实标签匹配(索引位置一样表示匹配)
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

# 这行代码会给我们一组布尔值。为了确定正确预测项的比例,我们可以把布尔值转换成浮点数,然后取平均值。例如,[True, False, True, True] 会变成 [1,0,1,1] ,取平均值后得到 0.75
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# 计算所学习到的模型在测试数据集上面的正确率
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

 

附上input_data.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =================================================
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值