利用TensorFlow进行Fashion MNIST数据集的基本分类问题
图像标签及对应的类别
label | class |
---|---|
0 | T-shirt/top(T-恤) |
1 | Trouser (裤子) |
2 | Pullover (套衫) |
3 | Dress(连衣裙) |
4 | Coat(大衣) |
5 | Sandal (凉鞋) |
6 | Skirt(衬衫) |
7 | Sneaker(运动鞋) |
8 | Bag(手提包) |
9 | ankle boot(踝靴) |
1、导入包
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
2、载入fashion_mnist的数据集
图像是28x28NumPy数组,像素值在0到255之间。标签是一个整数数组,范围从0到9
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images,train_labels),(test_images,test_labels) = fashion_mnist.load_data()
#补充每个类别的名称,组成一个list
class_names=['T-shirt/top','Trouser','Pullover','Dress','Coat', 'Sandal','Shirt','Sneaker','Bag','Ankle boot']
#输出训练集的特征的大小
print(train_images.shape)
#输出训练集的标签的大小
print(len(train_labels))
print(test_images.shape)
print(len(test_labels))
3、可视化第一个样本
#图像像素值位于 0 到 255 之间
plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.gca().grid(False)
plt.show()
4、预处理并可视化训练集前25张图像
归一化:对数据进行预处理,将其缩放,使所有像素值都在[0,1]区间内
train_images=train_images/255.0
test_images=test_images/255.0
#可视化
plt.figure(figsize=(10,10))#总体长和宽
#显示训练集中的前25张图像,并在每个图像下方显示类名
for i in range(25