MNIST数据集输出识别准确率用到的函数方法

MNIST数据集输出识别准确率

MINST数据集:

提供6w张28*28像素点的0~9手写数字图片和标签,用于训练;

提供1w张28*28像素点的0~9手写数字图片和标签,用于测试。
在这里插入图片描述
每张图片的784个像素点(28*28=784)组成长度为784的一维数组,作为输入特征

eg:[0. 0. 0. 0. 0. 0. 0.380 0.376 0.301 0.462 … … … 0.239 0. 0. 0. 0. 0. 0. 0. 0.]

图片的标签以一维数组形式给出,每个元素表示对应分类出现的概率。

eg:[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]这个表示数字6的概率是100%其他数字的概率均是0,意思就是图片应该是阿拉伯数字6。

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('./data/', one_hot=True)

Train and Validation:→训练和验证模型参数

  • train-images-idx3-ubyte.gz
  • train-labels-idx1-ubyte.gz

Test:→测试模型

  • t10k-images-idx3-ubyte.gz
  • t10k-labels-idx1-ubyte.gz

返回各子集样本数:

mnist.train.num_examples
mnist.validation.num_examples
mnist.test.num_examples

返回标签和数据:

>>>mnist.train.labels[0]   # 表示第0张图片的标签
array([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.])
>>>mnist.train.images[0]   # 表示第0张图片的784个像素点
array([0.             , 0.              ,0.
       0.             , 0.              ,0.
       0.             , 0.              ,0.
       0.             , 0.              ,0.
        ..................................
])
# 总共784个像素点

收一小撮数据,准备喂入神经网络训练:

>>>BATCH_SIZE = 200                   # 定义一小撮是多少
>>>xs, ys = mnist.train.next_batch(BATCH_SIZE)  # 从训练集中随机抽取BATCH_SIZE组个数据和标签
>>>print("xs shape:", xs.shape)
xs.shape: (200,784)   # 200行数据 每个数据有784个像素点
>>>print("ys shape:", ys.shape)
ys shape: (200,10)    # 200行数据 每个数据有10个元素是输出的分类

下面是一些常用的函数:

tf.get_collection("")        # 从集合中取全部变量,生成一个列表
tf.add_n([])                 # 列表内对应元素相加
tf.cast(x,dtype)			 # 把x转为dtype类型
tf.argmax(x,axis)			 # 返回最大值所在的索引号,如:tf.argmax([1,0,0],1) 返回0
os.path.join("home", "name") # 返回home/name
字符串.spilt()           	   # 按指定字符串对字符串进行切片,返回分割后的列表
with tf.Graph().as_default() as g:  # 其内定义的节点在计算图g

保存模型:

saver = tf.train.Saver()    # 实例化saver对象
with tf.Session as sess:
	for i in range(STEPS):
        if i % 轮数 == 0:
            saver.save(sess, os.path.join(MODEL_SAVE_PATH,MODEL_NAME), global_step=global_step)

加载模型:

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(存储路径)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)

实例化可还原滑动平均值的saver

ema = tf.train.ExponentialMovingAverage(滑动平均基数)
ema_restore = ema.variables_to_restore()
saver = tf.train.Saver(ema_restore)

准确率计算方法:

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值