下载mnist.npz数据集
链接:https://pan.baidu.com/s/1f8q1HDlObIdgtU1kqN99zA?pwd=9xub
提取码:9xub
首先我们可以查看数据集标签和样本的数量:
这里使用numpy库读取数据通过使用shape函数查看数据的信息
import numpy as np
# 根据你存放的路径修改
path = "../data/mnist.npz"
data = np.load(path)
x_train ,y_train = data['x_train'] , data['y_train']
x_test ,y_test = data['x_test'] , data['y_test']
data.close()
print('样本数据的相关信息为:train_x:%s, train_y:%s, test_x:%s, test_y:%s' % (x_train.shape, y_train.shape, x_test.shape, y_test.shape))
样本数据的相关信息为:train_x:(60000, 28, 28), train_y:(60000,), test_x:(10000, 28, 28), test_y:(10000,)
可以编写相关代码查看数据集中训练集和测试集中一部分数据对应的标签和图片。
我们可以定义函数进行展示。
import numpy as np
import matplotlib.pyplot as plt
def load_mnist():
# 根据你存放的路径修改
path = "../data/mnist.npz"
data = np.load(path)
x_train,y_train = data['x_train'],data['y_train']
x_test,y_test = data['x_test'],data['y_test']
data.close()
return (x_train,y_train),(x_test,y_test)
def main():
# 加载数据
(X_train,y_train_label),(test_image,test_label) = load_mnist()
# 设置4行数据,每行数据展示8个图片,前两行用于展示训练集的数据,后面两行可以展示测试集的数据
fig,ax = plt.subplots(nrows=4,ncols=8,sharex=True,sharey=True)
ax = ax.flatten()
# 输出训练集的前16个
for i in range(0,16):
# 获取训练集前16个数据
img = X_train[i].reshape(28,28)
# 获取测试集前16个数据
img2 = test_image[i].reshape(28,28)
# 前16个框对应训练集的前16个数据
ax[i].set_title(y_train_label[i])
# 设置图片为黑白的
ax[i].imshow(img,cmap='Greys',interpolation='nearest')
# 后16个框对应测试集的前16个数据
ax[i+16].set_title(test_label[i])
ax[i+16].imshow(img2,cmap='Greys',interpolation='nearest')
# 取消x轴和y轴的刻度
ax[0].set_xticks([])
ax[0].set_yticks([])
# 改变边框大小,并输出图片
plt.tight_layout()
plt.show()
if __name__ == '__main__':
main()
最终输出测试集与训练集的数据信息。