【tensorflow扩展库学习】MNIST手写数字分类

【tensorflow扩展库学习】MNIST手写数字分类

   项目介绍

        建立深度神经网络,将手写的数字(0到9)正确的分类。
       
        使用MNIST手写数字数据集,使用内置的使用方法即使检索数据集。
       
        尝试使用tensorflow的拓展库:tensorflow.contrib.learn实现,简化一般方法的实现过程
   

   使用contrib.learn实现

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from tensorflow.contrib import learn
from sklearn.metrics import confusion_matrix

import matplotlib.pyplot as plt

# 感谢网友的封装,原函数地址:https://blog.csdn.net/qq7835144/article/details/88952512
def cm_plot(original_label, predict_label, pic=None):
    cm = confusion_matrix(original_label, predict_label)   # 由原标签和预测标签生成混淆矩阵
    plt.figure()
    plt.matshow(cm, cmap=plt.cm.Blues)     # 画混淆矩阵,配色风格使用cm.Blues
    plt.colorbar()    # 颜色标签
    for x in range(len(cm)):
        for y in range(len(cm)):
            plt.annotate(cm[x, y], xy=(x, y), horizontalalignment='center', verticalalignment='center')
            # annotate主要在图形中添加注释
            # 第一个参数添加注释
            # 第二个参数是注释的内容
            # xy设置箭头尖的坐标
            # horizontalalignment水平对齐
            # verticalalignment垂直对齐
            # 其余常用参数如下:
            # xytext设置注释内容显示的起始位置
            # arrowprops 用来设置箭头
            # facecolor 设置箭头的颜色
            # headlength 箭头的头的长度
            # headwidth 箭头的宽度
            # width 箭身的宽度
    plt.ylabel('True label')  # 坐标轴标签
    plt.xlabel('Predicted label')  # 坐标轴标签
    plt.title('confusion matrix')
    if pic is not None:
        plt.savefig(str(pic) + '.jpg')
    plt.show()

# 作为MNIST数据集的存储路径
DATA_DIR = 'c:\\tmp\\data'
# 导入数据集,赋值给相应变量
data = input_data.read_data_sets(DATA_DIR, one_hot=False)
x_data, y_data = data.train.images, data.train.labels.astype(np.int32)
x_test, y_test = data.test.images, data.test.labels.astype(np.int32)

# 设置训练次数
NUM_STEPS = 2000
# 设置单次训练大小
MINIBATCH_SIZE = 128

# 为定义的输入' x_data '创建' feature_columns '对象
feature_columns = learn.infer_real_valued_columns_from_input(x_data)

# 建立深度神经网络分类器,200个隐藏单元,10各类别,学习率为0.2
dnn = learn.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[200],
    n_classes=10,
    optimizer=tf.train.ProximalAdagradOptimizer(
        learning_rate=0.2
    ))

# 使用fit()对分类器对象进行训练。我们向它传递协变量和目标变量,并设置步数和批量大小:
dnn.fit(x=x_data, y=y_data, steps=NUM_STEPS, batch_size=MINIBATCH_SIZE)

# 计算精确度,并输出
test_acc = dnn.evaluate(x=x_test, y=y_test, steps=1)["accuracy"]
print('test accuracy: {}'.format(test_acc))

# 预测结果,记录数据
y_pred = dnn.predict(x=x_test, as_iterable=False)
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
cnf_matrix = confusion_matrix(y_test, y_pred)

# 对比数据,生成图片
cm_plot(y_test, y_pred)

    输出

test accuracy: 0.9765999913215637

真实标签(行)与预测标签(列)数量混淆矩阵
    
      欢迎各位大佬交流讨论
      再次感谢网友的混淆矩阵图片实现,函数地址:https://blog.csdn.net/qq7835144/article/details/88952512

本文示例参考《TensorFlow学习指南——深度学习系统构建详解》第七章第二节

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值