# LeNet architecture
# LeNet-5 implementation (Keras)
# -*- coding: utf-8 -*-
"""
Train a LeNet-5 style convolutional network on MNIST with Keras.

The script loads MNIST, builds the network, trains it with SGD and saves
the trained model to an HDF5 file.

Created on Tue Mar 6 19:45:01 2018
@author: yuyangyg
"""
import keras
from keras.datasets import mnist

# Load the MNIST train/test splits (28x28 grayscale digit images + labels).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# One-hot encode the integer labels (10 digit classes) so they can be used
# with the categorical_crossentropy loss.
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

# Scale pixel values into [0, 1]. Cast to float32 first: dividing the raw
# integer arrays directly would promote them to float64, doubling memory for
# no benefit since Keras computes in float32 by default.
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

# The Conv2D input below is declared as (28, 28, 1), i.e. channels-last
# layout [batch, height, width, channel]; add the single grayscale channel.
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)
from keras.layers import Conv2D, MaxPool2D, Dense, Flatten
from keras.models import Sequential

# LeNet-5 style network. Feature-map shapes follow from input_shape=(28, 28, 1):
#   conv 3x3 'same'  -> 28x28x6,  max-pool 2x2 -> 14x14x6
#   conv 5x5 'valid' -> 10x10x16, max-pool 2x2 -> 5x5x16 -> flatten to 400
#   dense 120 -> dense 84 -> dense 10 (softmax over the 10 classes)
# Each conv and hidden dense layer uses a ReLU non-linearity. Without one, the
# three stacked Dense layers would collapse into a single linear map.
lenet = Sequential()
lenet.add(Conv2D(6, kernel_size=3, strides=1, padding='same',
                 activation='relu', input_shape=(28, 28, 1)))
lenet.add(MaxPool2D(pool_size=2, strides=2))
lenet.add(Conv2D(16, kernel_size=5, strides=1, padding='valid',
                 activation='relu'))
lenet.add(MaxPool2D(pool_size=2, strides=2))
lenet.add(Flatten())  # flatten the 5x5x16 feature maps into one vector
lenet.add(Dense(120, activation='relu'))
lenet.add(Dense(84, activation='relu'))
lenet.add(Dense(10, activation='softmax'))
lenet.summary()
# Optional: render the architecture to lenet.png by uncommenting these lines
# (plot_model relies on extra visualization packages being installed).
#from keras.utils import plot_model
#plot_model(lenet, to_file='lenet.png', show_shapes=True)

# Stochastic gradient descent with Keras' default 'sgd' settings; the loss
# matches the one-hot encoded labels.
lenet.compile('sgd', loss='categorical_crossentropy', metrics=['accuracy'])

# validation_data is passed as the documented (x_val, y_val) tuple.
# NOTE(review): the test split doubles as validation data here, so the
# reported validation accuracy is not an independent held-out estimate.
lenet.fit(x_train, y_train, batch_size=64, epochs=50,
          validation_data=(x_test, y_test))

lenet.save('myletnet.h5')  # save the trained model (architecture + weights)
# References:
# https://www.jianshu.com/p/7a0a3eefeea4
# https://github.com/SherlockLiao/lenet/blob/master/Lenet.ipynb