信心很重要,加油!!
目标:训练了一个神经网络模型来对运动鞋和衬衫等衣物的图像进行分类。
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
1.导入Fashion MNIST数据集
(train_images, train_labels), (test_images, test_labels) = keras.datasets.fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
1.1此处有坑:网络下载不下来数据。可以直接在这个网站下载,一共4份数据。直接下载:分别点击图中四个蓝色Download即可下载。
https://www.worldlink.com.cn/en/osdir/fashion-mnist.html
**1.2下载的数据存放位置
get_file()是通过url下载文件,其中cache_dir: 存储缓存文件的位置,为 None 时默认为 Keras 目录.我的keras默认目录就是C:\Users\lee.keras。
当然你也可以指定路径。
下载后放C:\Users\lee.keras\datasets\fashion-mnist。
*
*2.探索下数据:都是些可爱的服装图片
2.1单张图片**
plt.figure()
plt.imshow(train_images[100])
plt.colorbar()
plt.grid(False)
2.2 前25张图片
plt.figure(figsize=(10,10))
for i in range(25):
plt.subplot(5,5,i+1)
plt.xticks([])
plt.yticks([])
plt.imshow(train_images[i],cmap=plt.cm.binary)
plt.xlabel(class_names[train_labels[0]])
2.3像素的归一化
在将它们输入神经网络模型之前,将这些值缩放到0到1的范围。为此,将值除以255。以相同的方式预处理训练集和测试集非常重要:
train_images = train_images/255.0
test_images = test_images/255.0
3.训练网络
3.1搭建网络各层,明确节点数
该网络的第一层tf.keras.layers.Flatten将图像格式从二维数组(28 x 28像素)转换为一维数组(28 * 28 = 784像素)。可以将这一层看作是堆叠图像中的像素行并将它们排成一行。该层没有学习参数。它只会重新格式化数据。
像素展平后,网络由tf.keras.layers.Dense两层序列组成。这些是紧密连接或完全连接的神经层。第一Dense层有128个节点(或神经元)。第二层(也是最后一层)返回长度为10的logits数组。每个节点包含一个得分,该得分指示当前图像属于10类之一。
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(10)
])
3.2编译
在准备训练模型之前,还需要进行一些其他设置。这些是在模型的编译步骤中添加的:
损失函数 -衡量训练期间模型的准确性。您希望最小化此功能,以在正确的方向上“引导”模型。
优化器 -这是基于模型看到的数据及其损失函数来更新模型的方式。
指标 -用于监视培训和测试步骤。以下示例使用precision,即正确分类的图像比例。
model.compile(optimizer='adam', l oss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
3.3模型训练
model.fit(train_images, train_labels, epochs=10)
4.模型评估
测试数据集的准确性略低于训练数据集的准确性。训练准确性和测试准确性之间的差距代表过度拟合。
test_loss,test_auc = model.evaluate(test_images,test_labels,verbose=2)
5.模型预测分类的概率
pro = probability_model.predict(test_images)
print(pro[0])
np.argmax(pro[0])
输出:9 表明大概率属于第9类。np.argmax返回最大值的索引。
参考文章,感谢感谢:
https://blog.csdn.net/weixin_43440726/article/details/103864652
https://tensorflow.google.cn/tutorials/keras/classification?hl=zh_cn