4.Keras—手写数字识别
Keras内置了一些数据集如mnist、cifar10、cifar100
下面我们用keras来跑一个手写数字识别项目MNIST,
我们按照前面讲的三模块构建网络法
第一模块准备数据集
1导入数据集
#导入mnist
from keras.datasets import mnist
(x_train,y_train),(x_test,y_test) = mnist.load_data()
我们第一次导入模块的时候keras会自动下载mnist数据集,只有11M左右
2引入keras还有matplotlib(一会做可视化)
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers.core import Dense,Dropout
from keras.optimizers import SGD
from keras.utils import np_utils
3对数据进行处理
这里对数据进行处理
#变换
x_train = x_train.reshape(60000,784)
x_test = x_test.reshape(10000,784)
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
x_train /=255
x_test /=255
#这里除以255是与RGB(0,256)对应
第四步划分训练集与测试集
from sklearn.model_selection import train_test_split
x_train,x_val,y_train,y_val = train_test_split(x_train,y_train)
plt.imshow(x_train[4].reshape(28,28))
plt.imshow(x_val[1].reshape(28,28))
plt.show()#查看训练集
数据准备完成
第二模块网络模型构建
#模型构建
#序贯模型
model = Sequential()
model.add(Dense(512,activation='relu',input_shape=(784,)))
model.add(Dropout(0.2))
model.add(Dense(512,activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10,activation='softmax'))
model.compile(loss='categorical_crossentropy',optimizer=SGD(lr=0.001),
metrics=['accuracy'])
model.summary()#查看模型信息
第三步训练与测试
#开始训练
network_history = model.fit(x_train,y_train,batch_size=128,epochs=20,verbose=1,validation_data=(x_val,y_val))
#acc最终0.8516,val_acc:0.8816
我们通过可视化观察
#打印中间数据
def plot_history(network_history):
plt.figure()
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.plot(network_history.history['loss'])
plt.plot(network_history.history['val_loss'])
plt.legend(['Training','Validation'])
plt.show()
plt.figure()
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.plot(network_history.history['acc'])
plt.plot(network_history.history['val_acc'])
plt.legend(['Training','Validation'],loc='lower right')
plt.show()
plot_history(network_history)#做可视化
发现效果很好也没有过拟合,准确率也比较高