3_2_HelloWorld

一完整代码

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)

import tensorflow as tf
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])

W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x, W) + b)

y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

tf.global_variables_initializer().run()

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    train_step.run({x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
二代码解释
#MNIST手写数字识别来探索TensorFlow.MNIST(Mixed National Institute of Standards and Technology database)是一个非常简单的机器视觉数据集
#由几万张28像素*28像素的手写数字组成,这些图片只包含灰度值信息。我们的任务是对这些手写数字的图片进行分类

#首先对MNIST数据进行加载,TensorFlow为我们提供了一个非常方便的封装,可以直接加载MNIST数据呈我们期望的格式,在ipython命令行或者spyder中直接运行下面的代码
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

#查看mnist这个数据集可以看到训练集有55000个样本,测试集有10000个样本,同时验证集有5000个样本。每个样本都有它对应的标注信息,即label
#我们会在训练集上训练模型,在验证集上校验效果并决定何时完成训练,最后我们在测试集评测模型的效果(可通过准确率,召回率,F1-score等评测)

print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)

import tensorflow as tf
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])

#W的shape是[784,10],784是特征的维数,而后面的10代表有10类,因为Label在one-hot编码后是10维的向量。
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

#实现Softmax Regression算法,y=softmax(Wx+b),改写成下面的代码
#softmax是tf.nn下面的一个函数,而tf.nn包含了大量神经网络的组件,tf.matmul是TensorFlow中的矩阵乘法函数。
#TensorFlow最厉害的地方在于将forward和backward的内容都自动实现,只要接下来定义好loss,训练时将会自动求导并进行梯度下滑,
#完成对Softmax Regression模型参数的自动学习
y = tf.nn.softmax(tf.matmul(x, W) + b)

#为了训练模型,我们需要定义一个loss function来描述模型对问题的分类京都。Loss越小,代表模型的分类结果与真实值的偏差越小,也就是模型越精确。
#开始给模型全部填充了全零的参数,这样模型会有一个初始的loss,而训练的目的是不断将这个loss减少,直到达到一个全局最优或者局部最优解。
#通常使用cross-entropy作为loss function.
#cross-entropy定义如下,其中y是预测的概率分布,y'是真实的概率分布,通常用来判断模型对真实概率分布估计的准确程度。

#先定义一个placeholder,输入时真实的label, 用来计算cross-entropy.这里的y_ * tf.log(y)也就是前面的y'ilog(yi),
#tf.reduce_sum也及时求和,而tf.reduce_mean则用来对每个batch数据结果求均值
y_ = tf.placeholder(tf.float32, [None, 10])

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

#我们直接调用tf.train.GradientDescentOptimizer,并设置学习速率为0.5,优化目标设置为cross-entropy,得到进行训练的操作train_step.
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
#使用Tensorflow的全局参数初始化器tf.global_variables_initializer,并且直接运行它的run方法
tf.global_variables_initializer().run()

#对模型完成了训练
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    train_step.run({x: batch_xs, y_: batch_ys})

#对模型的准确率进行验证。tf.argmax是从一个 tensor中寻找最大值的序号,tf.argmax(y,1)就是求各个预测的数字中概率最大的那个,
#tf.argmax(y_,1)则是找样本的真实数字类别。而tf.equal方法来判断预测的数字类别是否就是正确的类别,最后返回计算分类是正确的操作correct_predition
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

#我们统计全部样本预测的accuracy,这里需要先用tf.cast将之前correct_prediction输出的bool值转换为float32,再求平均。
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#我们将测试数据的特征和Label输入评测流程accuracy,计算模型再测试集上的准确率,再将结果打印出来。
#使用Softmax Regression对Softmax Regression对MNIST数据进行分类识别,在测试集上平均准确率可达92%
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

#一般步骤,我们可以做的事情可以分为4个部分
(1)定义算法公式,也就是神经网络forward时的计算。
(2)定义loss,选定优化器,并指定优化器优化loss
(3)迭代地对数据进行训练
(4)在测试集或验证集上对准确率进行评测。
神经网络的最大价值在于对特征的自动提取和抽象,它免去了人工提取特征的繁琐
三代码执行结果
hadoop@zhangjinyuTensor-1:~/zhang$ python3.6 3_2_HelloWorld.py
WARNING:tensorflow:From 3_2_HelloWorld.py:17: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /home/hadoop/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /home/hadoop/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /home/hadoop/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /home/hadoop/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /home/hadoop/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
(55000, 784) (55000, 10)
(10000, 784) (10000, 10)
(5000, 784) (5000, 10)
2018-12-04 01:09:27.943357: I tensorflow/core/common_runtime/process_util.cc:69] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
0.9214
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值