5. Compile the model
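This is a binary classification problem, and the model outputs a probability from a single-unit layer with a sigmoid activation, so we use the binary_crossentropy loss function. For the optimizer we choose adam, a gradient-descent-based optimization algorithm.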
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'])
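To make the loss choice concrete, here is a minimal pure-Python sketch of how binary cross-entropy scores sigmoid outputs. The helper names `sigmoid` and `binary_crossentropy`, and the sample labels and predictions, are made up for illustration. This is not Keras's internal implementation.

```python
import math

def sigmoid(z):
    """Map a raw score (logit) to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over labels and predicted probabilities."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(y_true)

# Hypothetical labels and predicted probabilities (not taken from the model above)
y_true = [1, 0, 1, 0]
y_pred = [0.9, 0.1, 0.6, 0.4]
loss = binary_crossentropy(y_true, y_pred)
print(f"loss = {loss:.4f}")

# A model that always outputs sigmoid(0) = 0.5 scores ln(2), about 0.693
chance_loss = binary_crossentropy(y_true, [sigmoid(0.0)] * len(y_true))
print(f"chance-level loss = {chance_loss:.4f}")
```

The chance-level value of about 0.693 is a useful reference point: a loss close to it suggests the model is not yet doing better than guessing.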
6. Train the model
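# Use the test samples from index 15000 onward as the validation set
x_val = test_data[15000:]
y_val = test_labels[15000:]
# Train on the full training set (despite the "partial" in the names)
partial_x_train = train_data
partial_y_train = train_labels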
history = model.fit(partial_x_train,
partial_y_train,
epochs=20,
batch_size=512,
validation_data=(x_val, y_val),
verbose=1)
Output:
Epoch 1/20
49/49 [==============================] - 4s 61ms/step - loss: 0.6906 - accuracy: 0.5580 - val_loss: 0.6891 - val_accuracy: 0.5781
Epoch 2/20
49/49 [==============================] - 3s 55ms/step - loss: 0.6734 - accuracy: 0.7336 - val_loss: 0.6407 - val_accuracy: 0.8160
Epoch 3/20
49/49 [==============================] - 3s 63ms/step - loss: 0.6270 - accuracy: 0.7878 - val_loss: 0.5680 - val_accuracy: 0.8354
Epoch 4/20
49/49 [==============================] - 3s 61ms/step - loss: 0.5545 - accuracy: 0.8280 - val_loss: 0.5425 - val_accuracy: 0.7697
Epoch 5/20
49/49 [==============================] - 3s 59ms/step - loss: 0.4733 - accuracy: 0.8551 - val_loss: 0.4434 - val_accuracy: 0.8446
Epoch 6/20
49/49 [==============================] - 3s 57ms/step - loss: 0.4008 - accuracy: 0.8775 - val_loss: 0.3876 - val_accuracy: 0.8632
Epoch 7/20
49/49 [==============================] - 3s 59ms/step - loss: 0.3433 - accuracy: 0.8950 - val_loss: 0.3358 - val_accuracy: 0.8820
Epoch 8/20
49/49 [==============================] - 3s 58ms/step - loss: 0.2991 - accuracy: 0.9064 - val_loss: 0.3309 - val_accuracy: 0.8725
Epoch 9/20
49/49 [==============================] - 3s 57ms/step - loss: 0.2652 - accuracy: 0.9154 - val_loss: 0.3022 - val_accuracy: 0.8844
Epoch 10/20
49/49 [==============================] - 3s 57ms/step - loss: 0.2373 - accuracy: 0.9251 - val_loss: 0.3302 - val_accuracy: 0.8609
Epoch 11/20
49/49 [==============================] - 3s 57ms/step - loss: 0.2139 - accuracy: 0.9329 - val_loss: 0.2858 - val_accuracy: 0.8864
Epoch 12/20
49/49 [==============================] - 3s 57ms/step - loss: 0.1940 - accuracy: 0.9406 - val_loss: 0.2993 - val_accuracy: 0.8734
Epoch 13/20
49/49 [==============================] - 3s 57ms/step - loss: 0.1768 - accuracy: 0.9461 - val_loss: 0.2604 - val_accuracy: 0.8957
Epoch 14/20
49/49 [==============================] - 3s 57ms/step - loss: 0.1613 - accuracy: 0.9526 - val_loss: 0.2797 - val_accuracy: 0.8850
Epoch 15/20
49/49 [==============================] - 3s 57ms/step - loss: 0.1475 - accuracy: 0.9580 - val_loss: 0.2649 - val_accuracy: 0.8919
Epoch 16/20
49/49 [==============================] - 3s 57ms/step - loss: 0.1353 - accuracy: 0.9617 - val_loss: 0.2750 - val_accuracy: 0.8871
Epoch 17/20
49/49 [==============================] - 3s 58ms/step - loss: 0.1240 - accuracy: 0.9662 - val_loss: 0.2863 - val_accuracy: 0.8825
Epoch 18/20
49/49 [==============================] - 3s 57ms/step - loss: 0.1140 - accuracy: 0.9693 - val_loss: 0.2697 - val_accuracy: 0.8901
Epoch 19/20
49/49 [==============================] - 3s 59ms/step - loss: 0.1050 - accuracy: 0.9726 - val_loss: 0.2805 - val_accuracy: 0.8859
Epoch 20/20
49/49 [==============================] - 3s 59ms/step - loss: 0.0965 - accuracy: 0.9752 - val_loss: 0.2825 - val_accuracy: 0.8860
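One way to read the log above is to find the epoch where validation loss bottoms out. The sketch below copies the `val_loss` column from the log by hand into a list (in a live session the same values would come from `history.history['val_loss']`). The variable names are illustrative.

```python
# val_loss per epoch, copied from the training log above
val_loss = [
    0.6891, 0.6407, 0.5680, 0.5425, 0.4434,
    0.3876, 0.3358, 0.3309, 0.3022, 0.3302,
    0.2858, 0.2993, 0.2604, 0.2797, 0.2649,
    0.2750, 0.2863, 0.2697, 0.2805, 0.2825,
]

# Epochs are 1-based in the Keras log, list indices are 0-based
best_index = min(range(len(val_loss)), key=lambda i: val_loss[i])
best_epoch = best_index + 1
print(f"lowest val_loss = {val_loss[best_index]:.4f} at epoch {best_epoch}")

# How far the final epoch has drifted from the best one
drift = val_loss[-1] - val_loss[best_index]
print(f"val_loss after epoch {len(val_loss)} is {drift:.4f} above the minimum")
```

In this run the minimum falls at epoch 13, while the training loss keeps falling through epoch 20. That pattern usually suggests the model is starting to overfit. Stopping near the best epoch, for example with Keras's `EarlyStopping` callback, is a common response, though the run shown here does not do so.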
7. Evaluate the model
results = model.evaluate(test_data[:15000], test_labels[:15000], verbose=2)
print(results)
Output:
469/469 - 1s - loss: 0.2995 - accuracy: 0.8801 - 572ms/epoch - 1ms/step
[0.2995178997516632, 0.880133330821991]
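The `469/469` in the output counts batches, not samples: when no `batch_size` is passed, `model.evaluate` uses a default of 32. The sketch below reproduces that count and labels the two returned numbers. The helper `steps_per_epoch` is a hypothetical name, the training-set size of 25,000 is an assumption (the data loading is not shown in this section), and the pairing of names to values assumes the loss comes first, followed by the metrics given to `compile`.

```python
import math

def steps_per_epoch(num_samples, batch_size):
    """Number of batches needed to cover num_samples (the last batch may be smaller)."""
    return math.ceil(num_samples / batch_size)

# Evaluation: 15,000 samples at the Keras default batch size of 32
eval_steps = steps_per_epoch(15000, 32)
print(eval_steps)

# Training: assumes train_data holds 25,000 samples (not confirmed in this section)
train_steps = steps_per_epoch(25000, 512)
print(train_steps)

# Values copied from the printed results above
results = [0.2995178997516632, 0.880133330821991]
# Assumed order: loss first, then the metrics passed to compile()
named = dict(zip(["loss", "accuracy"], results))
print(f"test loss: {named['loss']:.4f}, test accuracy: {named['accuracy']:.2%}")
```

Both counts match the progress bars in the logs (`469/469` for evaluation, `49/49` for training), which is consistent with the assumed sizes.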
8. Plot accuracy and loss over time
model.fit() returns a History object. It contains a dictionary that records values such as loss, accuracy, val_loss, and val_accuracy for each epoch of training.
history_dict = history.history
history_dict.keys()
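For orientation, the dictionary has one key per tracked quantity, each mapping to a list with one value per epoch. The stand-in below mimics that shape using the first three epochs from the training log. `mock_history` is a hand-built dictionary, not the real `history.history`.

```python
# Hand-built stand-in for history.history, using the first three epochs of the log
mock_history = {
    "loss": [0.6906, 0.6734, 0.6270],
    "accuracy": [0.5580, 0.7336, 0.7878],
    "val_loss": [0.6891, 0.6407, 0.5680],
    "val_accuracy": [0.5781, 0.8160, 0.8354],
}

print(list(mock_history.keys()))

# Every list has one entry per epoch, so any key gives the epoch count
num_epochs = len(mock_history["loss"])
for epoch in range(1, num_epochs + 1):
    row = {name: values[epoch - 1] for name, values in mock_history.items()}
    print(epoch, row)
```

This is why the plotting code can build its x-axis from `range(1, len(acc) + 1)`: the length of any one list equals the number of epochs trained.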
import matplotlib.pyplot as plt
acc = history_dict['accuracy']
val_acc = history_dict['val_accuracy']
loss = history_dict['loss']
val_loss = history_dict['val_loss']
epochs = range(1, len(acc) + 1)
# 'bo' means "blue dot"
plt.plot(epochs, loss, 'bo', label='Training loss')
# 'b' means "solid blue line"
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
plt.clf()  # clear the figure
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
![Training and validation loss](https://i-blog.csdnimg.cn/blog_migrate/51d255b3b4428b69a4bbe37719c72f22.png)
![Training and validation accuracy](https://i-blog.csdnimg.cn/blog_migrate/a3814c76c42c54e61a2129336bcc6c8f.png)