10.2.2 Building an Image Classifier Using the Sequential API
import tensorflow as tf
Load the dataset with Keras
fashion_mnist = tf.keras.datasets.fashion_mnist
(X_train_full,Y_train_full),(X_test,Y_test) = fashion_mnist.load_data()
X_train_full.shape
X_train_full.dtype
dtype('uint8')
X_valid,X_train = X_train_full[:5000]/255.0,X_train_full[5000:]/255.0
Y_valid,Y_train = Y_train_full[:5000],Y_train_full[5000:]
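The split and scaling step can be checked on synthetic data without downloading Fashion MNIST. The following is a minimal sketch using NumPy only; the helper `split_and_scale` and the arrays `fake_images` / `fake_labels` are hypothetical stand-ins (a much smaller array than the real 60,000-image set), not part of the original notebook.

```python
import numpy as np

def split_and_scale(images, labels, n_valid):
    """Hold out the first n_valid samples for validation and scale pixels to [0, 1]."""
    x_valid = images[:n_valid] / 255.0
    x_train = images[n_valid:] / 255.0
    y_valid, y_train = labels[:n_valid], labels[n_valid:]
    return (x_train, y_train), (x_valid, y_valid)

# Synthetic stand-in with the same dtype and image size as X_train_full (assumption).
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(600, 28, 28), dtype=np.uint8)
fake_labels = rng.integers(0, 10, size=600, dtype=np.uint8)

(x_train, y_train), (x_valid, y_valid) = split_and_scale(fake_images, fake_labels, n_valid=50)
print(x_train.shape, x_valid.shape)
print(x_train.dtype, x_train.min(), x_train.max())
```

Dividing a `uint8` array by `255.0` produces floats, so the scaled arrays no longer have the original integer dtype.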
calass_names=["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot"]
calass_names[Y_train[0]]
'Coat'
# Display 5 images from the dataset
import numpy as np
import matplotlib.pyplot as plt

def plotImages(images_arr):
    # One row of five subplots, one image per subplot
    fig, axes = plt.subplots(1, 5, figsize=(20, 20))
    axes = axes.flatten()
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

plotImages(X_train_full[:5])
Build the neural network
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=[28,28]))
model.add(tf.keras.layers.Dense(300,activation='relu'))
model.add(tf.keras.layers.Dense(100,activation='relu'))
model.add(tf.keras.layers.Dense(10,activation='softmax'))
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten (Flatten)            (None, 784)               0
_________________________________________________________________
dense (Dense)                (None, 300)               235500
_________________________________________________________________
dense_1 (Dense)              (None, 100)               30100
_________________________________________________________________
dense_2 (Dense)              (None, 10)                1010
=================================================================
Total params: 266,610
Trainable params: 266,610
Non-trainable params: 0
_________________________________________________________________
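The parameter counts in the summary can be reproduced by hand: a Dense layer has one weight per input-unit pair plus one bias per unit. A small sketch of that arithmetic (the helper name `dense_params` is made up for illustration):

```python
def dense_params(n_inputs, n_units):
    """Weights (n_inputs * n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units

# Flattened 28x28 input, followed by the three Dense layers of the model
layer_sizes = [28 * 28, 300, 100, 10]
counts = [dense_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:])]
total = sum(counts)
print(counts)  # [235500, 30100, 1010]
print(total)   # 266610
```

The Flatten layer only reshapes its input, which is why it contributes 0 parameters.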
Get the list of layers
model.layers
[<tensorflow.python.keras.layers.core.Flatten at 0x181727ed0>,
<tensorflow.python.keras.layers.core.Dense at 0x18174b510>,
<tensorflow.python.keras.layers.core.Dense at 0x1816c6dd0>,
<tensorflow.python.keras.layers.core.Dense at 0x1819db690>]
hidden1 = model.layers[1]
hidden1.name
'dense'
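To make the role of each layer concrete, here is a rough NumPy sketch of the forward pass this architecture performs: flatten, two ReLU layers, then softmax. The weights are random illustrative values (not the trained model's parameters, and not Keras's default initializer), and the function names are hypothetical.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(42)
sizes = [784, 300, 100, 10]
# Random weights and zero biases, for illustration only.
params = [(rng.normal(0.0, 0.05, size=(m, n)), np.zeros(n))
          for m, n in zip(sizes[:-1], sizes[1:])]

def forward(images):
    a = images.reshape(len(images), -1)  # Flatten: (batch, 28, 28) -> (batch, 784)
    for i, (W, b) in enumerate(params):
        z = a @ W + b
        # Softmax on the last layer, ReLU on the hidden layers
        a = softmax(z) if i == len(params) - 1 else relu(z)
    return a

batch = rng.random((4, 28, 28))  # four fake images with pixel values in [0, 1)
probs = forward(batch)
print(probs.shape)
print(probs.sum(axis=1))
```

Each output row is a probability distribution over the 10 classes, so the rows sum to 1.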
Compile the model
model.compile(loss="sparse_categorical_crossentropy",
              optimizer="sgd",
              metrics=["accuracy"])
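The "sparse" loss is used here because the labels are integer class indices (0 to 9) rather than one-hot vectors. The sketch below, written in plain NumPy with made-up probabilities, shows that the sparse form gives the same value as the one-hot form; it is an illustration of the formula, not Keras's implementation.

```python
import numpy as np

def sparse_categorical_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean of -log(probability assigned to the true class index)."""
    y_prob = np.clip(y_prob, eps, 1.0)
    picked = y_prob[np.arange(len(y_true)), y_true]
    return float(-np.log(picked).mean())

def categorical_crossentropy(y_onehot, y_prob, eps=1e-7):
    """Same loss, but with one-hot encoded targets."""
    y_prob = np.clip(y_prob, eps, 1.0)
    return float(-(y_onehot * np.log(y_prob)).sum(axis=1).mean())

y_true = np.array([2, 0])                 # integer labels (hypothetical)
y_prob = np.array([[0.1, 0.2, 0.7],
                   [0.5, 0.3, 0.2]])      # predicted probabilities (hypothetical)
y_onehot = np.eye(3)[y_true]              # one-hot version of the same labels

sparse_loss = sparse_categorical_crossentropy(y_true, y_prob)
dense_loss = categorical_crossentropy(y_onehot, y_prob)
print(sparse_loss, dense_loss)
```

With one-hot labels, the equivalent Keras loss name would be "categorical_crossentropy".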
Train and evaluate the model
history = model.fit(X_train, Y_train, epochs=30,
                    validation_data=(X_valid, Y_valid))
Epoch 1/30
1719/1719 [==============================] - 3s 2ms/step - loss: 0.9727 - accuracy: 0.6985 - val_loss: 0.5125 - val_accuracy: 0.8194
Epoch 2/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.4996 - accuracy: 0.8286 - val_loss: 0.4797 - val_accuracy: 0.8380
Epoch 3/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.4465 - accuracy: 0.8449 - val_loss: 0.4384 - val_accuracy: 0.8422
Epoch 4/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.4217 - accuracy: 0.8543 - val_loss: 0.3948 - val_accuracy: 0.8624
Epoch 5/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.3968 - accuracy: 0.8602 - val_loss: 0.3886 - val_accuracy: 0.8608
Epoch 6/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.3769 - accuracy: 0.8689 - val_loss: 0.3790 - val_accuracy: 0.8690
Epoch 7/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.3710 - accuracy: 0.8683 - val_loss: 0.3654 - val_accuracy: 0.8710
Epoch 8/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.3573 - accuracy: 0.8722 - val_loss: 0.3590 - val_accuracy: 0.8738
Epoch 9/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.3512 - accuracy: 0.8748 - val_loss: 0.3454 - val_accuracy: 0.8796
Epoch 10/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.3366 - accuracy: 0.8803 - val_loss: 0.3636 - val_accuracy: 0.8720
Epoch 11/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.3282 - accuracy: 0.8825 - val_loss: 0.3608 - val_accuracy: 0.8692
Epoch 12/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.3149 - accuracy: 0.8870 - val_loss: 0.3322 - val_accuracy: 0.8844
Epoch 13/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.3131 - accuracy: 0.8882 - val_loss: 0.3289 - val_accuracy: 0.8816
Epoch 14/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.3071 - accuracy: 0.8903 - val_loss: 0.3529 - val_accuracy: 0.8742
Epoch 15/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2989 - accuracy: 0.8929 - val_loss: 0.3251 - val_accuracy: 0.8808
Epoch 16/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2921 - accuracy: 0.8932 - val_loss: 0.3413 - val_accuracy: 0.8790
Epoch 17/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2913 - accuracy: 0.8960 - val_loss: 0.3787 - val_accuracy: 0.8528
Epoch 18/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2789 - accuracy: 0.8999 - val_loss: 0.3294 - val_accuracy: 0.8796
Epoch 19/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2732 - accuracy: 0.9004 - val_loss: 0.3240 - val_accuracy: 0.8810
Epoch 20/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2653 - accuracy: 0.9047 - val_loss: 0.3100 - val_accuracy: 0.8898
Epoch 21/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2701 - accuracy: 0.9028 - val_loss: 0.3059 - val_accuracy: 0.8890
Epoch 22/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2556 - accuracy: 0.9085 - val_loss: 0.3142 - val_accuracy: 0.8808
Epoch 23/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2505 - accuracy: 0.9088 - val_loss: 0.3051 - val_accuracy: 0.8910
Epoch 24/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2472 - accuracy: 0.9122 - val_loss: 0.3039 - val_accuracy: 0.8902
Epoch 25/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2489 - accuracy: 0.9109 - val_loss: 0.3015 - val_accuracy: 0.8906
Epoch 26/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2434 - accuracy: 0.9126 - val_loss: 0.2960 - val_accuracy: 0.8962
Epoch 27/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2436 - accuracy: 0.9138 - val_loss: 0.2946 - val_accuracy: 0.8914
Epoch 28/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2351 - accuracy: 0.9147 - val_loss: 0.3031 - val_accuracy: 0.8902
Epoch 29/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2292 - accuracy: 0.9166 - val_loss: 0.2965 - val_accuracy: 0.8920
Epoch 30/30
1719/1719 [==============================] - 2s 1ms/step - loss: 0.2233 - accuracy: 0.9181 - val_loss: 0.3133 - val_accuracy: 0.8880
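The `1719` in each progress bar is the number of batches per epoch. Since `fit()` was called without `batch_size`, Keras falls back to its default of 32, and the figure follows from the training set size:

```python
import math

n_train = 60000 - 5000   # samples left after holding out 5,000 for validation
batch_size = 32          # Keras default when fit() is called without batch_size
steps_per_epoch = math.ceil(n_train / batch_size)
print(steps_per_epoch)   # 1719
```

The last batch of each epoch is smaller than the others, because 55,000 is not a multiple of 32.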
Create a pandas DataFrame and call plot()
import pandas as pd
import matplotlib.pyplot as plt
pd.DataFrame(history.history).plot(figsize=(8,5))
plt.gca().set_ylim(0,1)
plt.show()
[Figure: training and validation loss and accuracy per epoch (output_21_0.png)]
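Besides plotting, the same history values can be inspected numerically, for instance to find the epoch with the lowest validation loss. The sketch below hard-codes the values for epochs 26 to 30 from the training log above so that it runs on its own; in the notebook these lists would come from `history.history["loss"]` and `history.history["val_loss"]`.

```python
# loss and val_loss for epochs 26-30, copied from the training log
epochs = [26, 27, 28, 29, 30]
loss = [0.2434, 0.2436, 0.2351, 0.2292, 0.2233]
val_loss = [0.2960, 0.2946, 0.3031, 0.2965, 0.3133]

# Epoch with the lowest validation loss
best_idx = min(range(len(val_loss)), key=val_loss.__getitem__)
best_epoch = epochs[best_idx]

# Gap between validation and training loss at each epoch
gaps = [round(v - l, 4) for v, l in zip(val_loss, loss)]
print(best_epoch)  # 27
print(gaps)
```

In this run the training loss keeps falling while the validation loss stops improving after epoch 27, so the gap between the two widens toward the end.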
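Use the model to make predictions
Use the predict() method to make predictions on new instances. The model was trained on pixel values scaled to [0, 1], so new instances should be scaled the same way. Note that the output recorded below was produced with the unscaled X_test[:3], which is likely why the probabilities saturate at exactly 0 and 1; with scaled inputs the values are typically less extreme.
X_new = X_test[:3] / 255.0  # scale like the training data
y_proba = model.predict(X_new)
y_proba.round(2)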
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)
plotImages(X_test[:5])
![The first five test images](https://img-blog.csdnimg.cn/20201117142730581.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3BpZ3llbGxvdzk4,size_16,color_FFFFFF,t_70#pic_center)
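If you only care about the class with the highest estimated probability, you can use the predict_classes() method
y_pred = model.predict_classes(X_new)
y_pred  # This raises a warning that `model.predict_classes()` is deprecated (it is no longer available in recent TensorFlow releases); use `np.argmax(model.predict(x), axis=-1)` instead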
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/keras/engine/sequential.py:450: UserWarning: `model.predict_classes()` is deprecated and will be removed after 2021-01-01. Please use instead:* `np.argmax(model.predict(x), axis=-1)`, if your model does multi-class classification (e.g. if it uses a `softmax` last-layer activation).* `(model.predict(x) > 0.5).astype("int32")`, if your model does binary classification (e.g. if it uses a `sigmoid` last-layer activation).
warnings.warn('`model.predict_classes()` is deprecated and '
array([9, 2, 1])
np.array(calass_names)[y_pred]  # All three images are classified correctly
array(['Ankle boot', 'Pullover', 'Trouser'], dtype='<U11')
Using np.argmax(model.predict(x), axis=-1) instead
y_pred1=np.argmax(model.predict(X_new), axis=-1)
y_pred1
array([9, 2, 1])
The predictions are correct.
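The argmax step can be tried in isolation, without a trained model. This sketch rebuilds the rounded probability array shown earlier by hand and maps the winning indices to class names; the binary example at the end uses made-up sigmoid outputs to illustrate the second suggestion in the deprecation warning.

```python
import numpy as np

class_names = np.array(["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
                        "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"])

# Rounded probabilities copied from the predict() output shown earlier
y_proba = np.zeros((3, 10), dtype=np.float32)
y_proba[0, 9] = 1.0
y_proba[1, 2] = 1.0
y_proba[2, 1] = 1.0

# Multi-class case: index of the highest probability in each row
y_pred = np.argmax(y_proba, axis=-1)
print(y_pred)               # [9 2 1]
print(class_names[y_pred])  # ['Ankle boot' 'Pullover' 'Trouser']

# Binary case: threshold a sigmoid output at 0.5 (hypothetical values)
sigmoid_out = np.array([[0.2], [0.9]])
binary_pred = (sigmoid_out > 0.5).astype("int32").ravel()
print(binary_pred)          # [0 1]
```

Using `axis=-1` takes the argmax over the class dimension, so the same line works for a single instance or a whole batch.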