# model.py
from keras.layers import Activation, Convolution2D, Dropout, Dense, Flatten
from keras.layers.advanced_activations import PReLU
from keras.layers import AveragePooling2D, BatchNormalization
from keras.models import Sequential
# Convolutional stages as (filters, square kernel side, square average-pool side).
# Every stage uses 'same' padding and a pooling stride of 2.
_CONV_STAGES = (
    (16, 7, 5),
    (32, 5, 3),
    (32, 3, 3),
)


def _add_conv_stage(model, filters, kernel_side, pool_side, **conv_kwargs):
    """Append one Conv -> PReLU -> BatchNorm -> AvgPool -> Dropout stage to *model*.

    Args:
        model: The ``Sequential`` model to extend in place.
        filters: Number of convolution filters.
        kernel_side: Side length of the square convolution kernel.
        pool_side: Side length of the square average-pooling window.
        **conv_kwargs: Extra keyword arguments forwarded to the convolution
            layer (used to pass ``input_shape`` to the first stage only).
    """
    # Keras 2 signature: kernel size as a tuple and ``padding`` instead of the
    # Keras 1 positional (rows, cols) and ``border_mode`` arguments.
    model.add(Convolution2D(filters, (kernel_side, kernel_side),
                            padding='same', **conv_kwargs))
    # PReLU with default arguments learns one slope per activation element,
    # which is why its parameter count scales with the feature-map size.
    model.add(PReLU())
    model.add(BatchNormalization())
    model.add(AveragePooling2D(pool_size=(pool_side, pool_side),
                               strides=(2, 2), padding='same'))
    model.add(Dropout(0.5))


def simple_CNN(input_shape, num_classes):
    """Build a small CNN classifier ending in a softmax over ``num_classes``.

    The network is three convolutional stages (see ``_CONV_STAGES``), each
    followed by PReLU, batch normalization, stride-2 average pooling and 50%
    dropout. These feed two fully connected PReLU layers with 50% dropout and
    a final softmax layer.

    Args:
        input_shape: Shape of one input sample, passed to the first layer,
            e.g. ``(48, 48, 1)``. Presumably channels-last (height, width,
            channels); confirm against the backend's image data format.
        num_classes: Number of output classes (width of the softmax layer).

    Returns:
        An uncompiled ``keras.models.Sequential`` model.
    """
    model = Sequential()

    for stage_index, (filters, kernel_side, pool_side) in enumerate(_CONV_STAGES):
        # Only the first layer of a Sequential model takes the input shape.
        extra = {'input_shape': input_shape} if stage_index == 0 else {}
        _add_conv_stage(model, filters, kernel_side, pool_side, **extra)

    model.add(Flatten())
    # NOTE(review): 1028 units may have been intended as 1024. It is kept
    # as-is to preserve the existing model (and any saved weights).
    for _ in range(2):
        model.add(Dense(1028))
        model.add(PReLU())
        model.add(Dropout(0.5))

    model.add(Dense(num_classes))
    model.add(Activation('softmax'))
    return model
if __name__ == "__main__":
    # Build the classifier for 48x48 single-channel inputs with 7 classes and
    # print its layer summary. The shape is defined once and actually used
    # (previously an unused, contradictory (64, 64, 1) was declared here).
    input_shape = (48, 48, 1)
    num_classes = 7
    model = simple_CNN(input_shape, num_classes)
    model.summary()
CNN网络结构 (CNN network structure — output of `model.summary()` for `simple_CNN((48, 48, 1), 7)`):
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 48, 48, 16) 800
_________________________________________________________________
p_re_lu_1 (PReLU) (None, 48, 48, 16) 36864
_________________________________________________________________
batch_normalization_1 (Batch (None, 48, 48, 16) 64
_________________________________________________________________
average_pooling2d_1 (Average (None, 24, 24, 16) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 24, 24, 16) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 24, 24, 32) 12832
_________________________________________________________________
p_re_lu_2 (PReLU) (None, 24, 24, 32) 18432
_________________________________________________________________
batch_normalization_2 (Batch (None, 24, 24, 32) 128
_________________________________________________________________
average_pooling2d_2 (Average (None, 12, 12, 32) 0
_________________________________________________________________
dropout_2 (Dropout) (None, 12, 12, 32) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 12, 12, 32) 9248
_________________________________________________________________
p_re_lu_3 (PReLU) (None, 12, 12, 32) 4608
_________________________________________________________________
batch_normalization_3 (Batch (None, 12, 12, 32) 128
_________________________________________________________________
average_pooling2d_3 (Average (None, 6, 6, 32) 0
_________________________________________________________________
dropout_3 (Dropout) (None, 6, 6, 32) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 1152) 0
_________________________________________________________________
dense_1 (Dense) (None, 1028) 1185284
_________________________________________________________________
p_re_lu_4 (PReLU) (None, 1028) 1028
_________________________________________________________________
dropout_4 (Dropout) (None, 1028) 0
_________________________________________________________________
dense_2 (Dense) (None, 1028) 1057812
_________________________________________________________________
p_re_lu_5 (PReLU) (None, 1028) 1028
_________________________________________________________________
dropout_5 (Dropout) (None, 1028) 0
_________________________________________________________________
dense_3 (Dense) (None, 7) 7203
_________________________________________________________________
activation_1 (Activation) (None, 7) 0
=================================================================
Total params: 2,335,459
Trainable params: 2,335,299
Non-trainable params: 160
_________________________________________________________________