点击查看 (人工智能) 系列文章
mnist文件读取
目前 tenserflow 2.0版本已经取消了examples的模块,当然,还可以自己下载后将整个文件夹放在tensorflow目录下继续使用,不过官方已经推荐了新的用法,还是建议直接用新用法。
1.0 旧用法报错
2.0 推荐用法
从执行结果看到:
- tensorflow版本是2.6.0
- 从storage.googleapis.com下载文件用时2s
- 下载的训练数据有6万个,测试数据有1万个;其中训练图片样本为28*28数据,标签为1维数值
读取函数解析
tf.kera.datasets.mnist.load_data(path='')
参数Path:
本地缓存Mnist数据集(mnist.npz)的相对路径(~/.kera/datasets/),如果windows系统,则可能是(C:\\Users\\用户名\\)
返回值:
Numpy数组元组,具体为“(x_train, y_train),(x_test, y_test)”,即从前到后依次为训练集的图片和标签,测试集的图片和标签
下载本地缓存
数据显示
将前20个数据,使用4*5的矩阵显示出来:
源码
# author: suoxd123@126.com
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./mnist/dataset')
# author: suoxd123@126.com
import tensorflow as tf
print(tf.__version__)
# 读取数据
mnist = tf.keras.datasets.mnist
(imgTrain, labelTrain),(imgTest, labelTest) = mnist.load_data(path='mnist.npz')
# 显示数据基本信息
print(imgTrain.shape, imgTest.shape)
print(labelTrain.shape, labelTest.shape)
# author: suoxd123@126.com
import matplotlib.pyplot as plt
# 显示图片
fig = plt.figure()
for i in range(20):
plt.subplot(4,5, i+1)
plt.tight_layout() #自动适配子图尺寸
plt.imshow(imgTrain[i], cmap='binary')
plt.title("Label:{}".format(labelTrain[i]))
plt.xticks([]) # 删除坐标标记
plt.yticks([])