from tensorflow.keras.wrappers.scikit_learn import KerasRegressor, KerasClassifier
keras有这样一个sklearn风格的接口,可以满足sklearn风格的写法。
这里仅给出回归示例:
分类用法仅需修改为KerasClassifier,并根据需要修改为metrics=[‘acc’,‘mae’,‘mse’]等即可
def build_regresor_model(lr):
"""
构建网络,并编译
"""
model = Sequential()
model.add(Dense(units=128, activation='tanh'))
model.add(Dense(units=128, activation='tanh'))
# 最后一层只有一个单元,没有激活
model.add(Dense(units=1))
model.compile(optimizer=optimizers.Adam(lr=lr),
loss='mse',
metrics=['mae', 'mse'],
)
return model
regressor = KerasRegressor(build_fn=build_regresor_model, lr=0.001, batch_size=100, nb_epoch=20)
regressor.fit(x=train_X, y=train_Y) # 训练
pred_val_Y = regressor.predict(val_X) # 在验证集上预测
build_regresor_model
作为自定义函数的指针传入,lr
是自定义函数的传参,batch_size
是keras模型训练时的参数,批处理大小,nb_epoch
是keras模型训练时的参数,从 Keras 2.0 开始,nb_epoch参数已重命名为epochs,但是这个接口似乎被忽略了,还是nb_epoch,这点需要注意。
参考:
《What does nb_epoch in neural network stands for?》
《Value error during grid search - epochs is not a legal parameter》