#coding:utf-8
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Sequential
from keras.layers import Convolution2D,MaxPooling2D,Flatten,Dense,Dropout
from keras.optimizers import Adam
from keras.utils import plot_model
# Train and evaluate a small two-convolution CNN classifier on MNIST with Keras,
# then print the test-set loss and accuracy.
#
# Directory holding the MNIST data read by input_data.read_data_sets.
file = "../data/MNIST"
# Load the dataset with one-hot encoded labels (10 classes, matching the
# Dense(10) softmax output below).
mnist = input_data.read_data_sets(file, one_hot=True)
# Images are presumably stored as flat 784-pixel rows (TODO confirm for the
# input_data version in use); reshape them to (28, 28, 1) so they match the
# input_shape declared on the first convolution layer.
X_train = mnist.train.images
y_train = mnist.train.labels
X_train = X_train.reshape((-1, 28, 28, 1))
X_test = mnist.test.images
y_test = mnist.test.labels
X_test = X_test.reshape((-1, 28, 28, 1))

model = Sequential()
# Convolution 1 + ReLU: 30 filters of 5x5; 'same' padding with stride 1 keeps
# the 28x28 spatial size.
model.add(Convolution2D(input_shape=(28, 28, 1), filters=30, kernel_size=5,
                        strides=1, padding='same', activation='relu'))
# Pooling 1: a 2x2 window with stride 2 halves the feature map, 28x28 -> 14x14.
# (With strides=1 the pooling layers did not downsample at all, so Flatten saw
# 28*28*30 = 23520 features and the first Dense layer needed ~11.8M weights.)
model.add(MaxPooling2D(pool_size=(2, 2), strides=2, padding='same'))
# Convolution 2 + ReLU: 30 more 5x5 filters, spatial size unchanged (14x14).
model.add(Convolution2D(filters=30, kernel_size=5, strides=1, padding='same',
                        activation='relu'))
# Pooling 2: 14x14 -> 7x7.
model.add(MaxPooling2D(pool_size=(2, 2), strides=2, padding='same'))
# Flatten the 7x7x30 feature maps into a 1470-element vector per sample.
model.add(Flatten())
# Fully connected layer 1.
model.add(Dense(500, activation='relu'))
# Dropout with rate 0.5 to regularize the fully connected layer.
model.add(Dropout(0.5))
# Fully connected layer 2: softmax over the 10 digit classes.
model.add(Dense(10, activation='softmax'))

# Save a diagram of the architecture. This relies on the optional
# pydot/graphviz packages; a missing diagram dependency should not abort the
# training run, so report it and carry on.
try:
    plot_model(model, to_file='mnist_cnn.png', show_shapes=True)
except (ImportError, OSError) as err:
    print("Skipping model plot:", err)

# Adam optimizer with learning rate 1e-4.
# NOTE(review): newer Keras releases rename `lr` to `learning_rate`; `lr` is
# kept for the older Keras/TF1 stack this script imports from.
adam = Adam(lr=1e-4)
# Compile: categorical cross-entropy matches the one-hot labels.
model.compile(optimizer=adam, loss="categorical_crossentropy",
              metrics=['accuracy'])
# Train for 10 epochs in mini-batches of 100.
model.fit(X_train, y_train, batch_size=100, epochs=10)
# Evaluate on the test split and print (loss, accuracy).
loss, acc = model.evaluate(X_test, y_test)
print(loss, acc)
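# CNN手写数字识别-Keras版本 (CNN handwritten-digit recognition, Keras version)
# 最新推荐文章于 2022-07-28 14:39:06 发布 (blog-page residue: "latest recommended article published 2022-07-28 14:39:06")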