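Loading the MNIST data
from tensorflow import keras  # assumed import; the original does not show how keras was imported
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()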
Building the network
model = keras.Sequential([...])
Or, adding layers one at a time:
model = keras.Sequential()
model.add(...)
Layer details
Fully connected (Dense) layer
keras.layers.Dense(84, activation='sigmoid'),
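A Dense layer computes activation(x · W + b), so a layer with a given number of units holds one weight per (input, unit) pair plus one bias per unit. The sketch below counts those parameters in plain Python. The input size of 120 is an assumption chosen for illustration, since the original does not say what feeds this layer, and `dense_param_count` is a hypothetical helper, not a Keras function.

```python
def dense_param_count(input_dim: int, units: int, use_bias: bool = True) -> int:
    """Number of trainable parameters in a fully connected layer."""
    weights = input_dim * units          # one weight per (input, unit) pair
    biases = units if use_bias else 0    # one bias per unit
    return weights + biases


# Dense(84) fed by a hypothetical 120-dimensional input
params = dense_param_count(120, 84)
print(params)  # 10164
```

The same count is what `model.summary()` reports in its parameter column for a Dense layer.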
Convolutional layer
keras.layers.Conv2D(96, kernel_size=(11, 11), strides=(4, 4), activation='relu', use_bias=True, padding='valid'),
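The Conv2D arguments above fix both the spatial output size and the number of parameters. A rough sketch of that arithmetic in plain Python follows: with 'valid' padding nothing is padded, and with 'same' padding the output size is ceil(input / stride). The 227×227×3 input is an assumption for illustration (the original does not state the input shape), and both helper functions are hypothetical names.

```python
import math


def conv_output_size(size: int, kernel: int, stride: int, padding: str) -> int:
    """Output length along one spatial axis of a convolution or pooling layer."""
    if padding == "valid":
        return (size - kernel) // stride + 1
    if padding == "same":
        return math.ceil(size / stride)
    raise ValueError(f"unknown padding: {padding!r}")


def conv2d_param_count(in_channels: int, filters: int, kernel_h: int,
                       kernel_w: int, use_bias: bool = True) -> int:
    """Each filter has kernel_h * kernel_w * in_channels weights plus one bias."""
    weights = filters * kernel_h * kernel_w * in_channels
    return weights + (filters if use_bias else 0)


# Assumed input: 227 x 227 pixels, 3 channels
side = conv_output_size(227, kernel=11, stride=4, padding="valid")
params = conv2d_param_count(in_channels=3, filters=96, kernel_h=11, kernel_w=11)
print(side, params)  # 55 34944
```

Under these assumptions the layer would output a 55×55×96 tensor per image.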
Pooling layer
keras.layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same'),
Or
keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
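To show what the two pooling layers compute, here is a minimal plain-Python sketch that slides a window over a small 2D grid and keeps either the maximum or the mean of each window. It only covers 'valid' padding (no padding) and a single channel. The `pool2d` helper and the 4×4 `feature_map` are made up for illustration and are not part of Keras.

```python
def pool2d(grid, pool, stride, reduce_fn):
    """Slide a pool x pool window over a 2D list with no padding ('valid')."""
    rows, cols = len(grid), len(grid[0])
    out = []
    for r in range(0, rows - pool + 1, stride):
        out_row = []
        for c in range(0, cols - pool + 1, stride):
            window = [grid[r + i][c + j] for i in range(pool) for j in range(pool)]
            out_row.append(reduce_fn(window))
        out.append(out_row)
    return out


def mean(values):
    return sum(values) / len(values)


# Hypothetical 4 x 4 single-channel feature map
feature_map = [
    [1, 3, 2, 0],
    [5, 7, 4, 6],
    [9, 8, 1, 2],
    [3, 4, 5, 7],
]

max_pooled = pool2d(feature_map, pool=2, stride=2, reduce_fn=max)
avg_pooled = pool2d(feature_map, pool=2, stride=2, reduce_fn=mean)
print(max_pooled)  # [[7, 6], [9, 7]]
print(avg_pooled)  # [[4.0, 3.0], [6.0, 3.75]]
```

With a 2×2 window and stride 2, each spatial dimension is halved, which is the usual reason for placing a pooling layer after a convolution.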
Flatten layer (reshapes the feature maps into a vector before the fully connected layers)
keras.layers.Flatten(),
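Flatten has no weights: it only reshapes each sample from (height, width, channels) into one vector of height × width × channels values, in row-major order. A small plain-Python sketch of that reshaping, using a made-up sample of shape (2, 2, 3):

```python
def flatten(nested):
    """Recursively flatten a nested list in row-major order."""
    flat = []
    for item in nested:
        if isinstance(item, list):
            flat.extend(flatten(item))
        else:
            flat.append(item)
    return flat


# One hypothetical sample: height 2, width 2, 3 channels
sample = [
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
]

vector = flatten(sample)
print(len(vector))  # 12 == 2 * 2 * 3
```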
Compiling
model.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])
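The categorical cross-entropy loss used above compares a one-hot target with a predicted probability distribution: loss = -Σ y_true[i] · log(y_pred[i]). The sketch below computes it by hand for a single sample in plain Python. The small `eps` floor is a safeguard added in this sketch to avoid log(0), and the three-class vectors are invented for illustration.

```python
import math


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Cross-entropy between a one-hot target and predicted probabilities."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))


y_true = [0.0, 0.0, 1.0]   # one-hot target: the true class is index 2
y_pred = [0.1, 0.2, 0.7]   # hypothetical softmax output

loss = categorical_crossentropy(y_true, y_pred)
print(round(loss, 4))  # 0.3567, i.e. -log(0.7)
```

Because this loss expects one-hot targets, integer labels such as the raw MNIST `y_train` either need to be one-hot encoded first or used with `sparse_categorical_crossentropy` instead.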
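Training
label_train = keras.utils.to_categorical(y_train, 10)  # assumed definition: one-hot labels, as categorical_crossentropy expects
records = model.fit(x_train, label_train, epochs=1, validation_split=0.2)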
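Two details of the `fit` call are worth spelling out. First, `label_train` is presumably the one-hot form of `y_train`, since the loss chosen at compile time expects one-hot targets. Second, `validation_split=0.2` holds out the last 20% of the samples (taken before any shuffling) for validation. The plain-Python sketch below illustrates both; `to_one_hot` and `split_counts` are hypothetical helpers that mimic the behaviour and are not Keras functions.

```python
def to_one_hot(labels, num_classes):
    """Turn integer class labels into one-hot rows."""
    encoded = []
    for label in labels:
        row = [0.0] * num_classes
        row[label] = 1.0
        encoded.append(row)
    return encoded


def split_counts(num_samples, validation_split):
    """Sizes of the training part and the held-out tail used for validation."""
    num_train = int(num_samples * (1.0 - validation_split))
    return num_train, num_samples - num_train


print(to_one_hot([2, 0], num_classes=3))  # [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
print(split_counts(60000, 0.2))           # (48000, 12000)
```

For the 60,000 MNIST training images this leaves 48,000 samples for training and 12,000 for validation.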
Saving the model
model.save('my_model.h5')
Loading the model
model = keras.models.load_model('my_model.h5')
To be updated from time to time…