def cnn(input_shape, classes):
    """Build a small convolutional softmax classifier.

    Layer stack, in order:
      1. Conv2D with 25 filters, 3x3 kernel, 'same' padding, ReLU.
      2. MaxPool2D with the layer's default arguments.
      3. Dropout at rate 0.3, named 'student_feature1'.
      4. Flatten.
      5. Dense with 32 units and ReLU, named 'student_feature2'.
      6. Dense with `classes` units and softmax (the output layer).

    The two explicitly named layers are presumably meant to be looked up
    by name elsewhere (e.g. for feature extraction) -- confirm with callers
    before renaming them.

    Args:
        input_shape: Shape of a single input sample; forwarded unchanged to
            the first Conv2D layer's `input_shape` argument.
        classes: Number of units in the final softmax layer.

    Returns:
        The assembled Sequential model. It is not compiled here; the caller
        is responsible for calling `compile` before training.
    """
    model = Sequential()
    model.add(
        Conv2D(
            input_shape=input_shape,
            filters=25,
            kernel_size=(3, 3),
            padding='same',
            activation='relu',
        )
    )
    model.add(MaxPool2D())
    model.add(Dropout(rate=0.3, name='student_feature1'))
    model.add(Flatten())
    model.add(Dense(32, activation='relu', name='student_feature2'))
    model.add(Dense(classes, activation='softmax'))
    return model
# Instantiate the network with a (28, 28, 1) input shape and 10 output classes.
model = cnn(
    input_shape=(28, 28, 1),
    classes=10,
)
# Visualize the network architecture.
from keras.utils.vis_utils import plot_model
# Render the model graph to 'model.png', showing tensor shapes but hiding
# layer names. Pass the built `model` instance: `cnn` is only the factory
# function that creates it, and plot_model needs an actual Keras model.
plot_model(
    model,
    to_file='model.png',
    show_shapes=True,
    show_layer_names=False,
)