fashion_mnist分类模型的数据读取与显示

FashionMNIST 是一个替代 MNIST 手写数字集的图像数据集。 它是由 Zalando(一家德国的时尚科技公司)旗下的研究部门提供。其涵盖了来自 10 种类别的共 7 万个不同商品的正面图片。FashionMNIST 的大小、格式和训练集 / 测试集划分与原始的 MNIST 完全一致。60000/10000 的训练测试数据划分,28x28 的灰度图片。

经典的MNIST数据集包含了大量的手写数字。十几年来,来自机器学习、机器视觉、人工智能、深度学习领域的研究员们把这个数据集作为衡量算法的基准之一。我们会在很多的会议,期刊的论文中发现这个数据集的身影。实际上,MNIST数据集已经成为算法作者的必测的数据集之一。有人曾调侃道:“如果一个算法在MNIST不work, 那么它就根本没法用;而如果它在MNIST上work, 它在其他数据上也可能不work!”

制作这个数据集的目的就是取代MNIST,作为机器学习算法良好的“检测器”,用以评估各种机器学习算法。为什么不用MNIST了呢? 因为MNIST就现在的机器学习算法来说,是比较好分的,很多机器学习算法轻轻松松可以达到99%,因此无法区分出各类机器学习算法的优劣。

为了和MNIST兼容,Fashion-MNIST 与MNIST的格式,类别,数据量,train和test的划分,完全一致。

原文链接:https://blog.csdn.net/qq_36771850/article/details/86942434

 具体请见代码及其注释

# pip install pandas --ignore-installed pandas
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # 不显示等级2以下的提示信息
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import sys
import time
from tensorflow import keras
import tensorflow as tf


print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module. __name__,module.__version__)


# MNIST是一个非常有名的  手写体数字识别  数据集,在很多资料中,这个数据集都会被用作深度学习的入门样例
# 在直接使用这个数据集的时候,联网不能下载,控制台报错,所以根据他的提示,
# 将四个数据集都下载下来放到C:\Users\26225\.keras\datasets\fashion-mnist中,就能用了。
# 四个数据集:train-labels-idx1-ubyte.gz  train-images-idx3-ubyte.gz t10k-labels-idx1-ubyte.gz  t10k-images-idx3-ubyte.gz
fashion_mnist=keras.datasets.fashion_mnist
(x_train_all,y_train_all),(x_test,y_test)=fashion_mnist.load_data()#加载数据
# 将训练集拆分为为训练集和验证集
x_vaild,x_train=x_train_all[:5000],x_train_all[5000:]
y_vaild,y_train=y_train_all[:5000],y_train_all[5000:]
# shape在numpy中打印矩阵的行列数

print(x_vaild.shape,y_vaild.shape)
print(x_train.shape,y_train.shape)
print(x_test.shape,y_test.shape)
# 打印 验证集 训练集 测试集
# 5000张 尺寸28*28
# (5000, 28, 28) (5000,)
# (55000, 28, 28) (55000,)
# (10000, 28, 28) (10000,)

# 热力图是一种数据的图形化表示,具体而言,就是将二维数组中的元素用颜色表示。
# 热力图之所以非常有用,是因为它能够从整体视角上展示数据,更确切的说是数值型数据。
# 使用imshow()函数可以非常容易地制作热力图。
def show_single_image(img_arr):
     plt.imshow(img_arr,cmap="binary")#cmap定义颜色图谱,默认RGB,黑白图片用二进制显示就行
     plt.show()
#调用函数,打印训练集中图片
# show_single_image(x_train[0])


# 可视化一下图片以及对应的标签
# 展示多张图片
def show_imgs(n_rows,n_cols,x_data,y_data,class_names):

    # Python assert(断言)用于判断一个表达式,在表达式条件为 false的时候触发异常
    assert len(x_data)==len(y_data)# 验证x,y数据集相等
    assert n_rows*n_cols < len(x_data)# 验证行和列的乘机不能大于总样本数
    plt.figure(figsize=(n_cols*1.4,n_rows*1.6))#定义一张大图 尺寸为1.4 1.6
    #为每一行添加小图片
    for row in range(n_rows):
        for col in range(n_cols):
            index = n_cols*row+col #此小图片的索引值
            #subplot创建子图,划分为n_rows*n_cols分布的裂块,每一个裂块放上index+1的图片
            plt.subplot(n_rows,n_cols,index+1)
            # interpolation='nearest'当图片缩放时采用最近邻域内插法
            plt.imshow(x_data[index],cmap="binary",interpolation='nearest')
            plt.axis('off')#将小图坐标系关闭
            plt.title(class_names[y_data[index]]) #添加标题
    plt.show()

class_names =[' T- shirt', 'Trouser', ' Pullover','Dress' ,
    'Coat', ' Sandal',' Shirt', ' Sneaker',
    'Bag',' Ankle boot ']

show_imgs(5, 5,x_train, y_train,class_names)

 结果如图所示:

 

 

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
以下是使用TensorFlow实现时装分类的代码和注释: ``` python import tensorflow as tf from tensorflow import keras from keras.callbacks import ModelCheckpoint # 1. 读取数据 (train_images, train_labels), (test_images, test_labels) = keras.datasets.fashion_mnist.load_data() # 2. 模型编写 model = keras.Sequential([ keras.layers.Flatten(input_shape=(28, 28)), # 展平层,将输入的28x28的图像展平为一维向量 keras.layers.Dense(128, activation=tf.nn.relu), # 第一层全连接层,128个神经元,激活函数为ReLU keras.layers.Dense(128, activation=tf.nn.relu), # 第二层全连接层,128个神经元,激活函数为ReLU keras.layers.Dense(10, activation=tf.nn.softmax) # 输出层,输出10个类别的概率分布 ]) # 3. 选择SGD优化器 sgd = keras.optimizers.SGD(lr=0.01) # 4. 训练和评估 model.compile(optimizer=sgd, loss='sparse_categorical_crossentropy', metrics=['accuracy']) # 设置模型保存路径和文件名 checkpoint_path = "model/checkpoint.ckpt" checkpoint_dir = os.path.dirname(checkpoint_path) # 定义回调函数,每5个epoch保存一次模型 cp_callback = ModelCheckpoint(checkpoint_path, save_weights_only=True, verbose=1, period=5) # 开始训练模型 model.fit(train_images, train_labels, epochs=50, validation_data=(test_images, test_labels), callbacks=[cp_callback]) # 5. 打印模型结构 model.summary() # 6. 保存模型 # 6.a. 保存成ckpt格式 model.save_weights(checkpoint_path.format(epoch=0)) # 6.b. 保存成h5格式 model.save('model/fashion_mnist.h5') ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值