# Source code (源码): https://github.com/franneck94/TensorCross
# Install with: pip install tensorcross
from tensorcross.model_selection import GridSearch
# `dataset_split` is a tensorcross helper; it was called below without ever
# being imported (only `GridSearch` was), which raised a NameError.
from tensorcross.utils import dataset_split

# NOTE(review): this snippet expects the following names to be defined by
# the user before this point — confirm in the surrounding code:
#   * `tf`          - TensorFlow (e.g. `import tensorflow as tf`)
#   * `dataset`     - the full dataset to split (presumably a tf.data.Dataset)
#   * `build_model` - your model-building function; presumably GridSearch
#                     calls it with one value per `param_grid` key
#                     (`optimizer`, `learning_rate`) — verify in the docs.

# Split the data into a training part and a validation part.
# NOTE(review): presumably `split_fraction` is the share that ends up in
# `val_dataset` — confirm against the tensorcross documentation.
train_dataset, val_dataset = dataset_split(
    dataset=dataset,
    split_fraction=(1 / 3),
)

# Hyperparameter values to search over. The optimizers are passed as classes
# (not instances), presumably so `build_model` can construct them with the
# chosen `learning_rate`.
param_grid = {
    "optimizer": [
        tf.keras.optimizers.Adam,
        tf.keras.optimizers.RMSprop,
    ],
    "learning_rate": [0.001, 0.0001],
}

grid_search = GridSearch(
    model_fn=build_model,  # your model-building function
    param_grid=param_grid,
    verbose=1,
)

# NOTE(review): `epochs` / `verbose` here are presumably forwarded to the
# Keras training call for each candidate — verify in the tensorcross docs.
grid_search.fit(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    epochs=1,
    verbose=1,
)
grid_search.summary()