【问题记录】GridSearch在给Tensorflow 全连接网络调参时的坑 raise errors_impl.OpError(None, None, error_message, errors_

GridSearch是scikit-learn库的一个暴力调参工具,可以通过设定参数范围来寻找达到最优化目标的超参数,其底层原理按下不表。常见GridSearch和sklearn库中的各种机器学习算法相结合。
在使用GridSearch对tensorflow生成的神经网络模型进行调参时,其流程如下:
1.建立模型

def create_model(param):
	input=...
	output=...
	model=keras.Model(input,output)
	model.compile()
	return model

2.使用scikeras库中的KerasClassifier()方法将tensorflow模型打包为GridSearch可以直接调用的类。

import scikeras.wrappers import KerasClassifier
model=KerasClassifier(build_fn=create_model)

3.以字典形式建立需要调的参数,构建GridSearchCV对象,并利用数据集开始训练

reduceLR=ReduceLROnPlateau(monitor="accuracy",factor=0.9,patience=10,min_lr=1e-4)  #该函数用于生成学习率衰减对象
param_grid={"batch_size":np.arange(4,30,1),"callbacks":[reduceLR],"epochs":[500]}
grid=GridSearchCV(estimator=model,param_grid=param_grid,cv=10,verbose=3)
grid_result=grid.fit(train_x,train_y)

我在调一个简单的全连接神经网络时,遇到了一个错误
 raise errors_impl.OpError(None, None, error_message, errors_impl.UNKNOWN)
tensorflow.python.framework.errors_impl.OpError
在网上找了半天,也没有得到答案
最后看其他人的代码,发现了问题
原问题代码:

model=create_model()
model=KerasClassifier(model=model)

最后在scikeras的官网上看到官方的使用方式

def get_model(hidden_layer_dim, meta):
    # note that meta is a special argument that will be
    # handed a dict containing input metadata
    n_features_in_ = meta["n_features_in_"]
    X_shape_ = meta["X_shape_"]
    n_classes_ = meta["n_classes_"]

    model = keras.models.Sequential()
    model.add(keras.layers.Dense(n_features_in_, input_shape=X_shape_[1:]))
    model.add(keras.layers.Activation("relu"))
    model.add(keras.layers.Dense(hidden_layer_dim))
    model.add(keras.layers.Activation("relu"))
    model.add(keras.layers.Dense(n_classes_))
    model.add(keras.layers.Activation("softmax"))
    return model

clf = KerasClassifier(
    get_model,
    loss="sparse_categorical_crossentropy",
    hidden_layer_dim=100,
)

gs = GridSearchCV(clf, params, refit=False, cv=3, scoring='accuracy')

gs.fit(X, y)

官方使用是需要将构造模型的函数作为参数传入的,并且在方法KerasClassifier()的API中在这里插入图片描述
显示了,其model参数需要Callable[…,tf.keras.Model],即一个任意收入到keras.Model的一个映射函数。
因此,只需要将原代码改为

model=KerasClassifier(model=create_model)

即可。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值