fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
#训练集有60000张图片,前5000张图片作为验证集,后55000作为训练集
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
#输入是一个image的numpy数组
#显示一张图片
def show_single_image(img_arr):
plt.imshow(img_arr, cmap="binary")
plt.show()
show_single_image(x_train[1])
def show_imgs(n_rows, n_cols, x_data, y_data, class_names):
#y_data里存的是类别名的索引
#显示多张图片
assert len(x_data) == len(y_data)
assert n_rows * n_cols < len(x_data)
#定义一张大图
plt.figure(figsize = (n_cols * 1.4, n_rows * 1.6))
for row in range(n_rows):
for col in range(n_cols):
#当前行*总列数+列的偏移量
index = n_cols * row + col
plt.subplot(n_rows, n_cols, index+1) #在大图上画一张子图
plt.imshow(x_data[index], cmap="binary",
interpolation = 'nearest')
plt.axis('off') #关闭坐标
plt.title(class_names[y_data[index]])
plt.show()
class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress',
'Coat', 'Sandal', 'Shirt', 'Sneaker',
'Bag', 'Ankle boot']
show_imgs(3, 5, x_train, y_train, class_names)
note:如标题所示,本文代码着重于数据读取与展示,并未包含模型构建与训练部分,如果您想要关注其它部分的代码,请参照主页的其他博客
本文使用的数据集是keras里的资源fashion_mnist数据集
plt.subplot函数
它的调用是这样子的:subplot(numbRow , numbCol ,plotNum ) or subplot(numbRow numbCol plotNum),注意:可以不用逗号分开直接写在一起也是对的;
numbRow是plot图的行数;numbCol是plot图的列数;plotNum是指第几行第几列的第几幅图 ;
举个例子,如果是subplot (2 ,2 ,1),那么这个figure就是个2*2的矩阵图,也就是总共有4个图,1就代表了第一幅图