# Define the network structure (定义网络结构)
# Number of samples per batch handed to the train/validation generators.
batch_size: int = 32
def CNNmodel(input_shape=(224, 224, 3), weights="imagenet"):
    """Build a frozen-DenseNet121 model with a single linear output unit.

    A DenseNet121 backbone (without its classification head) is frozen and
    followed by global average pooling and a one-unit ``Dense`` layer with
    no activation. The result is one unbounded value per image, suitable
    for regression.

    Args:
        input_shape: Shape of one input image, ``(height, width, channels)``.
            Defaults to the previously hard-coded ``(224, 224, 3)``.
        weights: Passed straight through to ``DenseNet121`` (default
            ``"imagenet"``).

    Returns:
        An uncompiled Keras ``Model`` whose output is one value per image.
    """
    input_layer = Input(shape=input_shape)
    # Build the backbone from a shape rather than ``input_tensor=`` and call
    # it exactly once, so it is wired into the graph a single time as one
    # nested sub-model layer.
    dense_model = DenseNet121(include_top=False, weights=weights,
                              input_shape=input_shape)
    # Freeze every backbone layer so only the new head is trained.
    dense_model.trainable = False
    features = dense_model(input_layer)
    pooled = layers.GlobalAveragePooling2D()(features)
    output = layers.Dense(1)(pooled)
    return Model(inputs=input_layer, outputs=output)
# Build the model and print a layer-by-layer summary.
model = CNNmodel()
model.summary()

# Compile and train the model.
# ``learning_rate`` replaces the deprecated ``lr`` keyword, which the
# optimizers in recent TensorFlow releases reject. Other Adam settings are
# unchanged (note epsilon=1e-8 is set explicitly).
model.compile(optimizer=optimizers.Adam(learning_rate=0.001, beta_1=0.9,
                                        beta_2=0.999, epsilon=1e-8),
              loss='mse', metrics=['mae'])
# ``Model.fit`` accepts Python generators directly (tf.keras >= 2.1);
# ``fit_generator`` is deprecated there and no longer exists in Keras 3.
# NOTE(review): a fixed number of steps is drawn per epoch for 30 epochs,
# so the generators presumably loop indefinitely — confirm.
history = model.fit(train_image_generator(train_files, batch_size),
                    steps_per_epoch=len(train_files) // batch_size,
                    validation_data=val_image_generator(val_files, batch_size),
                    validation_steps=len(val_files) // batch_size,
                    epochs=30)

# Save the trained model to an HDF5 file.
model.save("DenseNet.hdf5")