一.引入库
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
# Report the TensorFlow, interpreter and library versions in use.
print(tf.__version__)
print(sys.version_info)
# Each of these modules exposes __name__ and __version__.
for module in mpl, np, pd, sklearn, tf, tf.keras:
    print(module.__name__, module.__version__)
2.5.0
sys.version_info(major=3, minor=7, micro=0, releaselevel='final', serial=0)
matplotlib 3.0.2
numpy 1.19.5
pandas 0.23.4
sklearn 0.20.2
tensorflow 2.5.0
tensorflow.keras 2.5.0
二.导入数据集
# Short alias for the Keras Fashion-MNIST dataset module; load_data() is called on it below.
fashion_mnist = tf.keras.datasets.fashion_mnist
三.加载数据并拆分训练集、验证集和测试集
# load_data() returns the (images, labels) pairs for the training and test sets.
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()

# Hold out the first 5000 training samples as the validation set ...
x_valid, y_valid = x_train_all[:5000], y_train_all[:5000]
# ... and keep everything after them as the actual training set.
x_train, y_train = x_train_all[5000:], y_train_all[5000:]

# Sanity-check the shapes of the three splits.
print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
(5000, 28, 28) (5000,)
(55000, 28, 28) (55000,)
(10000, 28, 28) (10000,)
四.展示图像
# Show a single image
def show_single_image(img_arr):
    """Render ``img_arr`` with ``plt.imshow`` and show the figure.

    Args:
        img_arr: Image data accepted by ``plt.imshow`` (the call below passes
            one 28x28 sample from the training set).
    """
    # The images are grayscale, so use matplotlib's "binary" (white-to-black)
    # colormap rather than the default colormap.
    plt.imshow(img_arr, cmap="binary")
    plt.show()
# Preview the first image of the training split.
show_single_image(x_train[0])
# Show several images in a grid
def show_imgs(n_rows, n_cols, x_data, y_data, class_names):
    """Plot the first ``n_rows * n_cols`` images of ``x_data`` as a titled grid.

    Images are placed in row-major order; each one is titled with
    ``class_names[label]`` where ``label`` is the matching entry of ``y_data``.

    Args:
        n_rows: Number of rows in the grid.
        n_cols: Number of columns in the grid.
        x_data: Indexable collection of images; each item is passed to
            ``plt.imshow``.
        y_data: Labels aligned with ``x_data``; each label is used as an
            index into ``class_names``.
        class_names: Sequence mapping a label to the title text.

    Raises:
        AssertionError: If ``x_data`` and ``y_data`` differ in length, or if
            there are fewer than ``n_rows * n_cols`` images.
    """
    assert len(x_data) == len(y_data), "x_data and y_data must have the same length"
    # The highest index read is n_rows * n_cols - 1, so an exact fit is valid
    # (the previous strict '<' rejected it).
    assert n_rows * n_cols <= len(x_data), "not enough images to fill the grid"
    plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))
    for row in range(n_rows):
        for col in range(n_cols):
            index = n_cols * row + col
            # Subplot positions are 1-based.
            plt.subplot(n_rows, n_cols, index + 1)
            plt.imshow(x_data[index], cmap="binary", interpolation="nearest")
            # Hide the axes so only the image and its title are visible.
            plt.axis("off")
            plt.title(class_names[y_data[index]])
    plt.show()
# Display names for the labels: show_imgs titles each image with
# class_names[label], so position i is the name shown for label i.
# NOTE(review): presumably follows the Fashion-MNIST label order - confirm
# against the dataset documentation.
class_names = [
    "T-shirt", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]
# Preview the first 15 training images as a 3 x 5 grid.
show_imgs(n_rows=3, n_cols=5, x_data=x_train, y_data=y_train, class_names=class_names)