TensorFlow学习笔记(二)——实例Softmax Regression识别手写数字

    在下学习TensorFlow的主要参考资料是电子工业出版社的[《TensorFlow实战》](https://book.douban.com/subject/26974266/),实例也基本都来自这本书。

使用TensorFlow的4个基本步骤

  1. 定义算法公式,即神经网络前向计算的公式
  2. 定义损失函数,选定优化器,并指定优化器优化损失函数
  3. 迭代地对数据进行训练
  4. 在测试集或验证集上对准确率进行评测

关于One-Hot编码

    在机器学习的应用任务中,对于非连续的数据经常也会使用数字进行编码,便于处理。例如“男性”编码为1,“女性”编码为2。但是这二者之间是不存在数学上的连续关系的,然而如果按照上述1和2进行编码的话,机器学习算法会认为“男性”和“女性”之间存在数学的有序关系。
    独热编码即One-Hot编码,又称一位有效编码,其方法是使用N位状态寄存器来对N个状态进行编码,每个状态都由他独立的寄存器位,并且在任意时候,其中只有一位有效。例如上文中说的“男性”和“女性”共有两种状态,那么就可以编码为01和10,对于有N个状态的特征,经过one-hot编码后就会变成N个二元值,而其中只有一个为1。
采用one-hot编码的好处主要有: 
  1. 解决了分类器不好处理属性数据的问题 
  2. 在一定程度上也起到了扩充特征的作用

实例:Softmax Regression识别手写数字

    代码如下,来自《TensorFlow实战》第三章,说明在代码注释中:
# 下载数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)  # 使用one-hot编码
print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)


import tensorflow as tf
sess = tf.InteractiveSession()
# 第一步,定义算法公式
x = tf.placeholder(tf.float32, [None, 784])  # 构建占位符,None表示样本的数量可以是任意的
W = tf.Variable(tf.zeros([784, 10]))  # 构建一个变量,代表权重矩阵,初始化为0
b = tf.Variable(tf.zeros([10]))  # 构建一个变量,代表偏置,初始化为0
y = tf.nn.softmax(tf.matmul(x, W) + b)  # 构建了一个softmax的模型:y = softmax(Wx + b),y指样本标签的预测值

# 第二步,定义损失函数,选定优化器,并指定优化器优化损失函数
y_ = tf.placeholder(tf.float32, [None, 10])
# 交叉熵损失函数
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
# 使用梯度下降法最小化cross_entropy损失函数
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# 第三步,迭代地对数据进行训练
tf.global_variables_initializer().run()
for i in range(1000):  # 迭代次数1000
    batch_xs, batch_ys = mnist.train.next_batch(100)  # 使用minibatch,一个batch大小为100
    train_step.run({x: batch_xs, y_: batch_ys})

# 第四步,在测试集或验证集上对准确率进行评测
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))  # tf.argmax()返回的是某一维度上其数据最大所在的索引值,在这里即代表预测值和真值
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # 用平均值来统计测试准确率
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))  # 打印测试信息


  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值