简述Fashion Mnist
Fashion mnist 是一个衣物数据集,集成在keras中可以直接使用。本文记录了一步一步利用 Fashion minst 的数据库训练 tensorflow 神经网络。
Step1:导入所有的模块
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
print(tf.__version__)
Step2:下载数据集
fashion_mnist = keras.datasets.fashion_mnist
(train_images , train_labels) , (test_images , test_labels) = fashion_mnist.load_data()
print("train images shapes:" + str(train_images.shape) + "test images shapes:" + str(test_images.shape))
fashion mnist是自带在keras中的,所以可以直接下载。拿到数据的时候,先对数据进行一个分割,分成训练集和测试集。此训练集中包含60000个数据集,每个数据集为28 * 28的图片。
Step3:查看一个单个的数据图片
plt.figure()
plt.imshow(train_images[0]) #show a picture
plt.colorbar() # show a color scale
plt.grid(True) # display the gtid
利用matplotlib来画出这张图片。imshow表示绘图,colorbar表示显示色阶,grid代表是否绘制网格。
Step4:数据归一化
数据的标准化(normalization)是将数据按比例缩放,使之落入一个小的特定区间。其主要目的在于将数据变化成[0 , 1]之间的数据,以方便后续对于数据的处理。
train_images = train_images / 255.0
test_images = test_images / 255.0
Step5:查看多组数据及其分类
class_names = ["T-shirt","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankie boot"]
plt.figure(figsize = (15 , 15))
for i in range(25):
plt.subplot(5 , 5, i + 1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(train_images[i] , cmap = plt.get_cmap('PuBuGn'))
plt.xlabel(class_names[train_labels[i]])
此处为何不知道为什么我的不支持中文显示,所以只能换成了英文分类。