基于tensorflow使用softmax算法训练手写数字MNIST识别任务

                                                   基于tensorflow使用softmax算法实现手写数字MNIST识别任务

利用深度学习解决的任务无外乎以下几个关键步骤:

1、明确任务需求,准备训练数据

2、搭建网络模型

3、计算损失函数

4、定义优化器

5、迭代训练

6、测试

 

1、明确任务,准备数据

           由于手写数字识别的任务需求是构建算法模型使其具有对手写的从0到9一共10个数字进行识别的功能。所以训练样本就是一定数量的0到9的数字图片,标签label就是数字0到9一共10类。比如数字图片3,它的label就是3:

                                                                         

手写数字识别一般都是采用MNIST数据集:

MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分:

1)Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本);

2)Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)

3)Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)

4)Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。

由于本篇博客是识别手写数字,所以直接利用代码下载MNIST数据集作为任务的训练集,其实做深度学习的任务训练集的构建是十分关键的一步,后面博客会逐步讲解如何构建自己的训练集进行各种深度学习任务的尝试和研究。

这里下载MNIST数据集很简单,如下所示:

from tensorflow.examples.tutorials.mnist import input_data
mnist =input_data.read_data_sets("MNIST_data/",one_hot=True)

one_hot设置为true,表示采用one_hot编码。

批量获取训练集和测试集以及对应的标签:

batch_xs,batch_ys=mnist.train.next_batch(batch_size)
batch_xs,batch_ys=mnist.test.next_batch(batch_size)

下载和打印查看MINST数据集的完整代码如下:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf

#download dataset
from tensorflow.examples.tutorials.mnist import input_data
mnist =input_data.read_data_sets("MNIST_data/",one_hot=True)

#print message of MNIST
print('train_datas:',mnist.train.images)
print('train_shape:',mnist.train.images.shape)
import pylab
im=mnist.train.images[1]#索引第一张图
im=im.reshape(-1,28)
pylab.imshow(im)
pylab.show()

#打印MNIST测试集和验证集的信息
print('test_shape:',mnist.test.images.shape)
print('val_shape:',mnist.validation.images.shape)

2、搭建网络模型

我这里就没有搭建比较复杂的神经网络(下一篇博客会搭建LeNet5网络进行手写数字的识别),就搭建一个线性函数层然后利用softmax进行分类:

# 定义输入输出占位符
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])

# 构建模型
w=tf.Variable(tf.random_normal([784,10]))
b=tf.Variable(tf.zeros([10]))
# 正向传播
net = tf.matmul(x,w)+b
pred=tf.nn.softmax(net)

说明:

  1. x与y的维数为什么是那样的?

因为MNIST数据集图片的尺寸都为28*28*1=784,是灰度图,所以通道数是1,None是为了适配训练时可设定需要的batch的维数,由于分类的类别是10类,所以y的维数是batch_sizex10。

  1. w和b是什么?

w是待学习的权重,b是待学习的偏置。

 

3、计算损失函数

设计的是交叉熵损失函数,当交叉熵损失函数采用的是softmax计算的预测值,那么就叫softmax交叉熵损失函数:

cost=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))

4、定义优化器

# 定义优化器
learning_rate=0.01
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

5、迭代训练

       设置训练的epoch数(注:一个epoch就是迭代遍历一遍训练集)以及batch_size,启动session进行训练,其他的代码,比如打印训练信息和保存模型等,直接请看下面的完整代码示例,关键代码都有注释说明,一目了然。

完整的训练代码示例:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import pylab

# 构建训练集
from tensorflow.examples.tutorials.mnist import input_data
mnist =input_data.read_data_sets("MNIST_data/",one_hot=True)

tf.reset_default_graph()
# 定义输入输出占位符
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])

# 构建模型
w=tf.Variable(tf.random_normal([784,10]))
b=tf.Variable(tf.zeros([10]))
# 正向传播
net = tf.matmul(x,w)+b
pred=tf.nn.softmax(net)
# 反向传播
cost=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
# 定义优化器
learning_rate=0.01
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

training_epochs=500
batch_size=64
display_step=1

model_path="log/MnistClass_model.ckpt"
# 启动session,迭代训练
saver = tf.train.Saver(max_to_keep=10)# max_to_keep是控制最大的模型保存数
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 启动循环开始训练
    for epoch in range(training_epochs):
        avg_cost=0
        total_batch=int(mnist.train.num_examples/batch_size)
        # 循环所有数据集
        for i in range(total_batch):
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            # 运行优化器
            _,c=sess.run([optimizer,cost],feed_dict={x:batch_xs,y:batch_ys})
            # 计算平均loss值
            avg_cost+=c/total_batch
            # 显示训练中的详细信息
            if(epoch+1)%display_step==0:
                print("Epoch:",'%04d'%(epoch+1),"cost=","{:.9f}".format(avg_cost))
                format(avg_cost)

    print("Finished!")
    # 测试model
    correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
    # 计算准确率
    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    print("Accuracy:",accuracy.eval({x:mnist.test.images,y:mnist.test.labels}))

    # 保存模型
    save_path=saver.save(sess,model_path)
    print("Model saved in file: %s" % save_path)


6、测试

测试时需要保证网络模型与训练是一致的,并且需要加载已经训练好的ckpt模型文件(后面博客会讲到如何把ckpt模型转成pb模型,那时测试会更方便,并且只有转成pb才方便部署。),主要用到saver.restore函数来加载ckpt模型,直接见完整的测试代码示例:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import pylab

from tensorflow.examples.tutorials.mnist import input_data
mnist =input_data.read_data_sets("MNIST_data/",one_hot=True)

tf.reset_default_graph()
#定义占位符
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])

#构建模型
w=tf.Variable(tf.random_normal([784,10]))
b=tf.Variable(tf.zeros([10]))
#正向传播
pred=tf.nn.softmax(tf.matmul(x,w)+b)


model_path="log/MnistClass_model.ckpt"
saver=tf.train.Saver()
#启动session
with tf.Session() as sess:
    # 初始化变量
    sess.run(tf.global_variables_initializer())
    #恢复模型变量
    saver.restore(sess,model_path)
    #测试model
    correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
    #计算准确率
    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#cast:类型转换
    print("Accuracy:",accuracy.eval({x:mnist.test.images,y:mnist.test.labels}))

    output=tf.argmax(pred,1)
    batch_xs,batch_ys=mnist.train.next_batch(2)
    outputval,predv=sess.run([output,pred],feed_dict={x:batch_xs})
    print(outputval,predv,batch_ys)

    im=batch_xs[0]
    im=im.reshape(-1,28)
    pylab.imshow(im)
    pylab.show()

    im=batch_xs[1]
    im=im.reshape(-1,28)
    pylab.imshow(im)
    pylab.show(

有问题欢迎评论交流,一起进步!

参考:

  1. MNIST数据集说明参考自:https://www.cnblogs.com/xianhan/p/9145966.html
  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值