depth (int): 网络层数
num_classes (int): 预测类别数
Return:
model (Model): 模型
“”"
if (depth - 2) % 6 != 0:
raise ValueError(‘depth should be 6n+2’)
#超参数
num_filters = 16
num_res_blocks = int((depth - 2) / 6)
inputs = keras.layers.Input(shape=input_shape)
x = resnet_layer(inputs=inputs)
for stack in range(3):
for res_block in range(num_res_blocks):
strides = 1
if stack > 0 and res_block == 0:
strides = 2
y = resnet_layer(inputs=x,num_filters=num_filters,
strides=strides)
y = resnet_layer(inputs=y,num_filters=num_filters,
activation=None)
if stack > 0 and res_block == 0:
x = resnet_layer(inputs=x,
num_filters=num_filters,
kernel_size=1,
strides=strides,
activation=None,
batch_normalization=False)
x = keras.layers.add([x,y])
x = keras.layers.Activation(‘relu’)(x)
num_filters *= 2
x = keras.layers.AveragePooling2D(pool_size=8)(x)
x = keras.layers.Flatten()(x)
outputs = keras.layers.Dense(num_classes,activation=‘softmax’,<