TensorFlow的MNIST学习

版权声明:【http://thinkgamer.cn】 https://blog.csdn.net/Gamer_gyt/article/details/80039242



打开微信扫一扫,关注微信公众号【数据与算法联盟】

转载请注明出处:http://blog.csdn.net/gamer_gyt
博主微博:http://weibo.com/234654758
Github:https://github.com/thinkgamer


MNIST是在机器学习领域中的一个经典问题。该问题解决的是把28x28像素的灰度手写数字图片识别为相应的数字,其中数字的范围从0到9.

数据集 目的
data_sets.train 55000 组 图片和标签, 用于训练。
data_sets.validation 5000 组 图片和标签, 用于迭代验证训练的准确性。
data_sets.test 10000 组 图片和标签, 用于最终测试训练的准确性。

数据集简介

MNIST数据集加载有两种办法,第一是直接从网上下载,第二是下载到本地进行load(跟第一种类似,只不过是事先下载好,从本地进行加载)。从网上下载到本地方式如下:

# 加载mnist数据集
from tensorflow.examples.tutorials.mnist import input_data
print("load finish")

mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
print(type(mnist))

输出为:

load finish
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
<class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'>

print("MNIST 训练集数据条数:" ,mnist.train.num_examples)
print("MNIST 测试集数据条数:" ,mnist.test.num_examples)

train_img = mnist.train.images
train_label = mnist.train.labels
print("训练集类型:",type(train_img))
print("训练集维度:",train_img.shape)

test_img = mnist.test.images
test_label = mnist.test.labels
print("测试集类型:",type(test_img))
print("测试集维度:",test_img.shape)

输出为:

MNIST 训练集数据条数: 55000
MNIST 测试集数据条数: 10000
训练集类型: <class 'numpy.ndarray'>
训练集维度: (55000, 784)
测试集类型: <class 'numpy.ndarray'>
测试集维度: (10000, 784)

打开当前运行代码的目录,我们会发现一个MNIST_data的文件夹,里边包含的文件如下:

文件 内容
train-images-idx3-ubyte.gz 训练集图片 - 55000 张 训练图片, 5000 张 验证图片
train-labels-idx1-ubyte.gz 训练集图片对应的数字标签
t10k-images-idx3-ubyte.gz 测试集图片 - 10000 张 图片
t10k-labels-idx1-ubyte.gz 测试集图片对应的数字标签

使用next_batch函数加载指定条数的数据集

# 关于next_batch函数
batchSize = 100
batch_x,batch_y = mnist.train.next_batch(batch_size=batchSize)
print(batch_x.shape)
print(batch_y.shape)

输出为:

(100, 784)
(100, 10)



打开微信扫一扫,加入数据与算法交流大群

阅读更多

没有更多推荐了,返回首页