利用tensorflow实现MNIST手写数字识别(单层神经网络)

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
# 参数1:下载到的目录,参数2:标签数据格式是不一样的
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# print('训练集train的数量:', mnist.train.num_examples,
#       '测试集validation的数量:', mnist.validation.num_examples,
#       '验证集test的数量:', mnist.test.num_examples)
# print('train images shape:', mnist.train.images.shape,
#       'labels shape:', mnist.train.labels.shape)
# print(mnist.train.images[0].reshape(28,28))
# 对数据进行可视化
# import matplotlib.pyplot as plt
# def plot_image(image):
#       plt.imshow(image.reshape(28, 28),cmap='binary')
#       plt.show()
#
# print(plot_image(mnist.train.images[1000]))

# 数据的批量读取,一次读入十个数据
# next_batch()会对内部的数据先做buffle
# batch_images_xs, batch_labels_ys = mnist.train.next_batch(batch_size=10)
# print(batch_images_xs.shape,batch_labels_ys.shape)

# 定义待输入数据的占位符
# mnist数据集中的每张图片共有28x28=784个像素点
x = tf.placeholder(tf.float32, [None, 784], name='X')
# 0-9一共10个数字,10个类别
y = tf.placeholder(tf.float32, [None, 10], name='Y')

# 以正态分布的随机数初始化权重w
W = tf.Variable(tf.random_normal([784, 10]), name='W')
# 以常数0初始化偏置值b
b = tf.Variable(tf.zeros([10]), name='b')
# 用单个神经元构建神经网络,前向计算
forward = tf.matmul(x, W) + b
# softMax分类
pred = tf.nn.softmax(forward)

# 训练模型
# 设置训练参数
train_epochs = 50 #训练轮数
batch_size = 100 #单次训练样本数(批次大小)
# 一轮训练的批次数
total_batch = int(mnist.train.num_examples/batch_size)
display_step = 1 #显示粒数
learning_rate = 0.01 #学习率

# 定义交叉熵损失函数
loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))

# 选择优化器
# 梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)

# 定义准确率
# 检查预测类别和实际类别的匹配情况,tf.argmax()参数2:1表示列
correct_prediction = tf.equal(tf.argmax(pred, 1),tf.argmax(y, 1))
# 准确率,将布尔值转化为浮点值,tf.cast()转化数据类型
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 声明会话
sess = tf.Session()
# 变量初始化
init = tf.global_variables_initializer()
sess.run(init)

# 训练模型
for epoch in range(train_epochs):
      for batch in range(total_batch):
            # 读取批次训练数据
            xs, ys = mnist.train.next_batch(batch_size)
            # 执行批次训练
            sess.run(optimizer,feed_dict={x:xs, y:ys})
      # 在total_batch批次数据训练完成后,使用验证数据计算误差和准确率,验证集不分批
      loss, acc = sess.run([loss_function,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
      # 打印训练过程中的详细信息
      if (epoch+1) % display_step == 0:
            print('训练轮次:','%02d' % (epoch+1),
                  '损失:','{:.9f}'.format(loss),
                  '准确率:','{:.4f}'.format(acc))

print('训练结束')

# 评估模型
accu_test = sess.run(accuracy,feed_dict={x:mnist.test.images, y:mnist.test.labels})
print('测试集上准确率:', accu_test)

accu_validation = sess.run(accuracy,feed_dict={x:mnist.validation.images, y:mnist.validation.labels})
print('验证集上准确率:', accu_validation)

accu_train = sess.run(accuracy,feed_dict={x:mnist.train.images, y:mnist.train.labels})
print('训练集上准确率:', accu_train)

# 由于预测结果是one-hot编码,需要转换为0-9数字
prediction_result = sess.run(tf.argmax(pred, 1), feed_dict={x:mnist.test.images})

# 查看预测结果的前10项
print(prediction_result[0:10])

# 可视化
import matplotlib.pyplot as plt
import numpy as np

def plot_images_labels_prediction(images,      #图像列表
                                   labels,      #标签列表
                                   predication, #预测值列表
                                   index,       #从第index个开始显示
                                   num=10):    # 缺省一次显示10幅
      fig = plt.gcf()               #获取当前图表,get current figure
      fig.set_size_inches(10,12)    #设为英寸,1英寸=2.53厘米
      if num > 25:
            num = 25          #最多显示25个子图
      for i in range(0,num):
            ax = plt.subplot(5,5, i+1)    #获取当前要处理的子图
            # 显示第index图像
            ax.imshow(np.reshape(images[index],(28,28)),cmap='binary')

            # 构建该图上显示的title
            title = 'label=' + str(np.argmax(labels[index]))
            if len(predication) > 0:
                  title += ",predict=" + str(predication[index])

            # 显示图上的title信息
            ax.set_title(title,fontsize=10)
            ax.set_xticks([]) # 不显示坐标轴
            ax.set_yticks([])
            index += 1

      plt.show()

plot_images_labels_prediction(mnist.test.images,
                               mnist.test.labels,
                               prediction_result, 10,25)

运行截图为

 

predict为利用该模型的预测值,label为标记值 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值