MNIST数据集使用详解

数据集下载网址:http://yann.lecun.com/exdb/mnist/
下载后无需解压,将其放在一个文件夹下即可:
在这里插入图片描述
数据说明:
数据集常被分为2~3个部分
训练集(train set):用来学习的一组例子,用来适应分类器的参数[即权重]
验证集(validation set):一组用于调整分类器参数(即体系结构,而不是权重)的示例,例如选择神经网络中隐藏单元的数量
测试集(test set):一组仅用于评估完全指定分类器的性能[泛化]的示例

读取方式:

import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
mnist_data_folder="/MNIST_data"	#指定数据集所在的位置(见上图存放格式)
mnist=input_data.read_data_sets(mnist_data_folder,one_hot=True)	#读取mnist数据集,指定标签格式one_hot=True

#获取数据集的个数
train_nums=mnist.train.num_examples
validation_nums=mnist.validation.num_examples
test_nums=mnist.test.num_examples
print("MNIST训练数据集个数 %d"%train_nums)
print("MNIST验证数据集个数 %d"%validation_nums)
print("MNIST测试数据集个数 %d"%test_nums)

#获取数据值
train_data=mnist.train.images   #所有训练数据
val_data=mnist.validation.images    #(5000,784)
test_data=mnist.test.images
print("训练集数据大小:",train_data.shape)
print("一幅图像大小:",train_data[1].shape)
print("一幅图像的列表表示:\n",train_data[1])

#获取标签值
train_labels=mnist.train.labels     #(55000,10)
val_labels=mnist.validation.labels  #(5000,10)
test_labels=mnist.test.labels   #(10000,10)
print("训练集标签数组大小L: ",train_labels.shape)
print("一幅图像的标签大小: ",train_labels[1].shape)
print("一幅图像的标签值:",train_labels[1])

#批量获取数据和标签  使用 next_batch(batch_size) 
#注意使用改方式时数据是随机读取的,但在同一批次中,数据和标签位置是对应的
batch_size=100  #每次批量训练100幅图像
batch_xs,batch_ys=mnist.train.next_batch(batch_size)
testbatch_xs,testbatch_ys=mnist.test.next_batch(batch_size)
print("使用mnist.train.next_batch(batch_size)批量读取样本")
print("批量随机读取100个样本,数据集大小= ",batch_xs.shape)
print("批量随机读取100个样本,标签集大小= ",batch_ys.shape)
print("批量随机读取100个测试样本,数据集大小= ",testbatch_xs.shape)
print("批量随机读取100个测试样本,标签集大小= ",testbatch_ys.shape)

#显示图像
plt.figure()
for i in range(10):
    im=train_data[i].reshape(28,28)	#训练数据集的第i张图,将其转化为28x28格式
    #im=batch_xs[i].reshape(28,28)	#该批次的第i张图
    plt.imshow(im)
    plt.pause(0.1)	#暂停时间
plt.show()

运行结果:

MNIST训练数据集个数 55000
MNIST验证数据集个数 5000
MNIST测试数据集个数 10000
训练集数据大小: (55000, 784)
一幅图像大小: (784,)
一幅图像的列表表示:
 [0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.12156864 0.5176471  0.9960785
 0.9921569  0.9960785  0.8352942  0.32156864 0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.08235294
 0.5568628  0.91372555 0.98823535 0.9921569  0.98823535 0.9921569
 0.98823535 0.8745099  0.07843138 0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.48235297 0.9960785  0.9921569  0.9960785
 0.9921569  0.87843144 0.7960785  0.7960785  0.8745099  1.
 0.8352942  0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.7960785  0.9921569  0.98823535 0.9921569  0.8313726  0.07843138
 0.         0.         0.2392157  0.9921569  0.98823535 0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.16078432 0.95294124 0.87843144
 0.7960785  0.7176471  0.16078432 0.59607846 0.11764707 0.
 0.         1.         0.9921569  0.40000004 0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.15686275 0.07843138 0.         0.
 0.40000004 0.9921569  0.19607845 0.         0.32156864 0.9921569
 0.98823535 0.07843138 0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.32156864 0.83921576
 0.12156864 0.4431373  0.91372555 0.9960785  0.91372555 0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.24313727
 0.40000004 0.32156864 0.16078432 0.9921569  0.909804   0.9921569
 0.98823535 0.91372555 0.19607845 0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.59607846 0.9921569  0.9960785
 0.9921569  0.9960785  0.9921569  0.9960785  0.91372555 0.48235297
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.59607846 0.98823535 0.9921569  0.98823535 0.9921569
 0.98823535 0.75294125 0.19607845 0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.24313727
 0.7176471  0.7960785  0.95294124 0.9960785  0.9921569  0.24313727
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.15686275 0.6745098  0.98823535 0.7960785  0.07843138 0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.08235294 0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.7176471  0.9960785  0.43921572 0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.24313727
 0.7960785  0.6392157  0.         0.         0.         0.
 0.         0.         0.         0.         0.2392157  0.9921569
 0.5921569  0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.08235294 0.83921576 0.75294125 0.
 0.         0.         0.         0.         0.         0.
 0.         0.04313726 0.8352942  0.9960785  0.5921569  0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.40000004 0.9921569  0.5921569  0.         0.         0.
 0.         0.         0.         0.         0.16078432 0.8352942
 0.98823535 0.9921569  0.43529415 0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.16078432 1.
 0.8352942  0.36078432 0.20000002 0.         0.         0.12156864
 0.36078432 0.6784314  0.9921569  0.9960785  0.9921569  0.5568628
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.6745098  0.98823535 0.9921569
 0.98823535 0.7960785  0.7960785  0.91372555 0.98823535 0.9921569
 0.98823535 0.9921569  0.50980395 0.07843138 0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.08235294 0.7960785  1.         0.9921569  0.9960785
 0.9921569  0.9960785  0.9921569  0.9568628  0.7960785  0.32156864
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.07843138 0.5921569  0.5921569  0.9921569  0.67058825 0.5921569
 0.5921569  0.15686275 0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.        ]

训练集标签数组大小L:  (55000, 10)
一幅图像的标签大小:  (10,)
一幅图像的标签值: [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
使用mnist.train.next_batch(batch_size)批量读取样本
批量随机读取100个样本,数据集大小=  (100, 784)
批量随机读取100个样本,标签集大小=  (100, 10)
批量随机读取100个测试样本,数据集大小=  (100, 784)
批量随机读取100个测试样本,标签集大小=  (100, 10)

plots显示:
在这里插入图片描述

在学会怎样读取后,我们可以用一个简单的神经网络来测试一下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist_data_folder="/MNIST_data"
mnist=input_data.read_data_sets(mnist_data_folder,one_hot=True)


#创建两个占位符,x为输入网络的图像,y_为输入网络的图像类别
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])

#权重初始化函数
def weight_variable(shape):
    #输出服从截尾正态分布的随机值
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

#偏置初始化函数
def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

#创建卷积op
#x 是一个4维张量,shape为[batch,height,width,channels]
#卷积核移动步长为1。填充类型为SAME,可以不丢弃任何像素点
def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding="SAME")

#创建池化op
#采用最大池化,也就是取窗口中的最大值作为结果
#x 是一个4维张量,shape为[batch,height,width,channels]
#ksize表示pool窗口大小为2x2,也就是高2,宽2
#strides,表示在height和width维度上的步长都为2
def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1,2,2,1],
                          strides=[1,2,2,1], padding="SAME")

#第1层,卷积层
#初始化W为[5,5,1,32]的张量,表示卷积核大小为5*5,第一层网络的输入和输出神经元个数分别为1和32
W_conv1 = weight_variable([5,5,1,32])
#初始化b为[32],即输出大小
b_conv1 = bias_variable([32])

#把输入x(二维张量,shape为[batch, 784])变成4d的x_image,x_image的shape应该是[batch,28,28,1]
#-1表示自动推测这个维度的size
x_image = tf.reshape(x, [-1,28,28,1])

#把x_image和权重进行卷积,加上偏置项,然后应用ReLU激活函数,最后进行max_pooling
#h_pool1的输出即为第一层网络输出,shape为[batch,14,14,1]
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

#第2层,卷积层
#卷积核大小依然是5*5,这层的输入和输出神经元个数为32和64
W_conv2 = weight_variable([5,5,32,64])
b_conv2 = weight_variable([64])

#h_pool2即为第二层网络输出,shape为[batch,7,7,1]
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

#第3层, 全连接层
#这层是拥有1024个神经元的全连接层
#W的第1维size为7*7*64,7*7是h_pool2输出的size,64是第2层输出神经元个数
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])

#计算前需要把第2层的输出reshape成[batch, 7*7*64]的张量
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

#Dropout层
#为了减少过拟合,在输出层前加入dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

#输出层
#最后,添加一个softmax层
#可以理解为另一个全连接层,只不过输出时使用softmax将网络输出值转换成了概率
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])

y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

#预测值和真实值之间的交叉墒
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))

#train op, 使用ADAM优化器来做梯度下降。学习率为0.0001
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

#评估模型,tf.argmax能给出某个tensor对象在某一维上数据最大值的索引。
#因为标签是由0,1组成了one-hot vector,返回的索引就是数值为1的位置
correct_predict = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))

#计算正确预测项的比例,因为tf.equal返回的是布尔值,
#使用tf.cast把布尔值转换成浮点数,然后用tf.reduce_mean求平均值
accuracy = tf.reduce_mean(tf.cast(correct_predict, "float"))


train_data=mnist.train.images
train_labels=mnist.train.labels
test_data=mnist.test.images
test_labels=mnist.test.labels

batch_size=100  #每次批量训练100幅图像
batch_xs,batch_ys=mnist.train.next_batch(batch_size)    #随机抓取训练数据中的100个批处理数据点
test_xs,test_ys=mnist.test.next_batch(batch_size)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) #初始化变量
    for i in range(2000):  #开始训练模型,循环2000次,每次传入一张图像
        sess.run(train_step,feed_dict={x:[train_data[i]], y_:[train_labels[i]], keep_prob:0.5})
        if(i%100==0):   #每100次,传入一个批次的测试数据,计算其正确率
            print(sess.run(accuracy, feed_dict={x: test_xs, y_: test_ys, keep_prob: 1.0}))

"""
也可以批量导入训练,注意使用mnist.train.next_batch(batch_size),得到的批次数据每次都会自动随机抽取这个批次大小的数据
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) #初始化变量
    for i in range(200):  #开始训练模型,循环200次,每次传入一个批次的图像
        sess.run(train_step,feed_dict={x:batch_xs, y_:batch_ys, keep_prob:0.5})
        if(i%20==0):   #每20次,传入一个批次的测试数据,计算其正确率
            print(sess.run(accuracy, feed_dict={x: test_xs, y_: test_ys, keep_prob: 1.0}))
"""

运行结果:

Extracting /MNIST_data\train-images-idx3-ubyte.gz
Extracting /MNIST_data\train-labels-idx1-ubyte.gz
Extracting /MNIST_data\t10k-images-idx3-ubyte.gz
Extracting /MNIST_data\t10k-labels-idx1-ubyte.gz
0.12
0.4
0.54
0.54
0.61
0.7
0.73
0.84
0.77
0.8
0.8
0.86
0.88
0.85
0.88
0.88
0.85
0.92
0.89
0.84

参考:
train set、 validation set 、test set三者的概念
MNIST手写数字数据集读取方法

  • 15
    点赞
  • 109
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值