1-需要正确安装的包
最新要求,https://github.com/adriangb/scikeras
要想使用tf.keras.wrappers.scikit_learn的 KerasClassifier, KerasRegressor,打包自定义的模型
需安装如下的包
pip install scikeras[tensorflow]
2-import规则
最新要求,https://adriangb.com/scikeras/stable/migration.html
导入方式为:
from scikeras.wrappers import KerasClassifier, KerasRegressor
3-关于def构建model,是否compile,有2种说法
3.1 def包含compile,则在SciKeras中不在编译model,并且参数无法传入。
Compile your model within model_build_fn and return this compiled model. In this case, SciKeras will not re-compile your model and all compilation parameters (such as optimizer) given to scikeras.wrappers.BaseWrapper.init() will be ignored.
3.2def不包含compile,则在SciKeras中编译model,并且传入相关参数,便于随机搜索和网格搜索参数的传入。
Return an uncompiled model from model_build_fn and let SciKeras handle the compilation. In this case, SciKeras will apply all of the compilation parameters, including instantiating losses, metrics and optimizers.
详细参考https://adriangb.com/scikeras/stable/advanced.html#compilation-of-model
4-参数传入网格搜索或随机搜索,打包时参数的,采用‘__’传入方式,例如下方示例。
详细参考https://adriangb.com/scikeras/stable/advanced.html#optimizer
clf = KerasClassifier(
model=model_build_fn,
optimizer=keras.optimizers.SGD,
optimizer__learning_rate=0.05
)