keras+mnist+LeNet_5实现
具体实现步骤:
- 加载Mnist数据集,可视化训练集样本数据
- 对数据集进行标准化预处理(归一化和独热编码处理)
- keras构建LeNet_5经典模型,打印模型参数
- 模型反向编译,模型训练
- 模型训练过程acc和loss可视化展示
import keras
from keras import datasets
from keras.layers import Conv2D,MaxPool2D,Dense,Flatten,Activation
from keras.models import Sequential
from keras.optimizers import SGD
import numpy as np
from matplotlib import pyplot as plt
# 1. 数据集处理
# 加载数据,输出样本和标签数据
(X_train,Y_trian),(X_test,Y_test) = datasets.mnist.load_data()
print("训练集:",X_train.shape)
print("训练集标签",Y_trian.shape)
print("测试集",X_test.shape)
print("测试集标签",Y_test.shape)
# 2可视化图像
plt.figure(figsize=(15,15))
plt.title('mnist')
for i in range (20):
plt.subplot(4, 5, i + 1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(X_test[i])
plt.xlabel(Y_test[i])
plt.savefig('../TermimalTest2020/Images/MnistDate.jpg')
plt.show()
# 手写体种类,0~9 一共10个类别
num_class = 10
# 3图像数据标准化处理
X_train = np.reshape(X_train,[-1,28,28,1])
X_test = np.reshape(X_test ,[-1,28,28,1])
X_train = X_train/255.0
X_test = X_test/255.0
# 独热编码处理
Y_trian = keras.utils.np_utils.to_categorical(Y_trian,num_classes=num_class)
Y_test = keras.utils.np_utils.to_categorical(Y_test,num_classes=num_class)
print(Y_trian[0])
# 4. 模型LeNET-5模型
model = Sequential([
keras.layers.Conv2D(6, (5,5), activation='tanh', input_shape=(28,28,1),padding='same'),
keras.layers.MaxPool2D((2,2),strides=(1,1)),
keras.layers.Conv2D(16, (5,5), activation='tanh',padding='same'),
keras.layers.MaxPool2D((2,2),strides=(1,1)),
keras.layers.Flatten(),
keras.layers.Dense(120, activation='tanh'),
keras.layers.Dense(84, activation='tanh'),
keras.layers.Dense(num_class),
keras.layers.Activation('softmax')
])
model.summary()
# 5. 数据反向传播
model.compile(loss='categorical_crossentropy',
optimizer=SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True),
metrics=['acc'])
#保存权值模型
model.save_weights('../models/mnist.h5')
# 6. 数据训练
epochs = 10
batch_size =32
result=model.fit(x=X_train,y=Y_trian,epochs=epochs,batch_size=batch_size,verbose=2,
validation_split=0.2)
# 模型训练可视化
plt.subplot(1, 2, 1)
plt.plot(range(1, epochs + 1), result.history["loss"])
plt.plot(range(1, epochs + 1), result.history["val_loss"])
plt.title("loss")
plt.xlabel("epochs")
plt.legend()
plt.ylabel("val/train loss")
plt.subplot(1, 2, 2)
plt.plot(range(1, epochs + 1), result.history["acc"])
plt.plot(range(1, epochs + 1), result.history["val_acc"])
plt.title("acc")
plt.xlabel("epochs")
plt.ylabel("val/train acc")
plt.legend()
plt.show()
plt.close()