方式2:本文介绍如何使用 Keras 函数式 API 搭建网络。
1.定义输入
由于这里使用的是 Keras 函数式 API,网络的输入不能直接用形状列表表示,必须先通过 Input() 创建为 Keras 张量(tensor)。
from keras.layers import Input
# Assumed input: a 224x224 image with 3 channels (batch dimension is implicit).
X_input = Input(shape=(224, 224, 3))
2.搭建网络
def net(X_input, out_put):
    """Build a small CNN classifier on top of a Keras input tensor.

    Layer stack: zero-pad -> 7x7/2 conv -> 3x3/2 max-pool -> ReLU
    -> 3x3 conv -> batch norm -> ReLU -> average pool -> flatten
    -> softmax dense.

    Args:
        X_input: Keras input tensor created with ``Input(...)`` (not a
            shape list). It feeds the first layer and is also used as the
            ``inputs`` of the returned ``Model``.
        out_put: number of units in the final softmax ``Dense`` layer,
            i.e. the number of output classes.

    Returns:
        A tuple ``(X, My_model_png)``: ``X`` is the softmax output tensor
        and ``My_model_png`` is a ``Model`` wrapping ``X_input -> X``,
        intended for ``summary()`` / ``plot_model`` visualisation.
    """
    # Function-scope imports: in the visible source only ``Input`` is
    # imported from keras.layers, so these names would otherwise be
    # undefined when net() is called.
    from keras.layers import (
        Activation,
        AveragePooling2D,
        BatchNormalization,
        Conv2D,
        Dense,
        Flatten,
        MaxPooling2D,
        ZeroPadding2D,
    )
    from keras.models import Model

    # Pad 3 pixels on every spatial side (224 -> 230 for a 224x224 input)
    # so the 7x7, stride-2, 'valid' convolution below yields 112x112.
    X = ZeroPadding2D(padding=(3, 3))(X_input)
    X = Conv2D(64, kernel_size=(7, 7), strides=(2, 2), name='c1')(X)
    X = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='mx_pp1')(X)
    # ReLU after max-pool gives the same values as before it (max and ReLU
    # commute), but runs on the smaller pooled feature map.
    X = Activation('relu', name='Ac1')(X)

    X = Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), name='c2')(X)
    X = BatchNormalization(name='Bn1')(X)
    X = Activation('relu', name='Ac2')(X)

    # Average pooling with Keras' default pool size.
    X = AveragePooling2D(name='AVG_pool_last')(X)
    # Flatten the feature map into a vector for the dense classifier.
    X = Flatten()(X)
    # Fully connected softmax output layer.
    X = Dense(out_put, activation='softmax', name='fc')(X)

    # Wrap the graph in a Model so it can be summarised / plotted.
    My_model_png = Model(inputs=X_input, outputs=X, name='My_model_png')
    return X, My_model_png
if __name__ == '__main__':
    # Example: 224x224 input with 3 channels, 5 output classes.
    X_input = Input(shape=(224, 224, 3))
    out_put = 5
    # Only the Model is needed here; the raw output tensor is discarded.
    _, My_model_png = net(X_input, out_put)
    My_model_png.summary()

    # Placeholder: replace with a real file path ending in ".png".
    path_png = '图片存储路径,需构建出.png'
    # plot_model requires the pydot package and Graphviz to be installed.
    from keras.utils import plot_model
    plot_model(My_model_png, to_file=path_png, show_shapes=True)
输出(My_model_png.summary() 的打印结果):
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 224, 224, 3) 0
_________________________________________________________________
zero_padding2d_1 (ZeroPaddin (None, 230, 230, 3) 0
_________________________________________________________________
c1 (Conv2D) (None, 112, 112, 64) 9472
_________________________________________________________________
mx_pp1 (MaxPooling2D) (None, 55, 55, 64) 0
_________________________________________________________________
Ac1 (Activation) (None, 55, 55, 64) 0
_________________________________________________________________
c2 (Conv2D) (None, 53, 53, 128) 73856
_________________________________________________________________
Bn1 (BatchNormalization) (None, 53, 53, 128) 512
_________________________________________________________________
Ac2 (Activation) (None, 53, 53, 128) 0
_________________________________________________________________
AVG_pool_last (AveragePoolin (None, 26, 26, 128) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 86528) 0
_________________________________________________________________
fc (Dense) (None, 5) 432645
=================================================================
Total params: 516,485
Trainable params: 516,229
Non-trainable params: 256