使用keras对自己的模型通过GridSearchCV进行调参

GridSearchCVh是一个智能的调参工具,根据该工具我们可以让模型在一组参数里面自行选择最有的参数。一般keras自带的有对分类和回归两类问题的调参,但是对于一些复杂的模型,或者自己写的模型,如果我们想使用这个方法往往需要重构源码的函数。
class KerasRegressorMy(BaseWrapper):

def __init__(self, lr=0.0001, latent_dim=128, num_well=189, batch_size=2, epoch=1, factor=0.6):
    self.lr = lr
    self.latent_dim = latent_dim
    self.num_well = num_well
    self.batch_size = batch_size
    self.epoch = epoch
    self.factor = factor
    # self.iterate = self.makeMyModel()

def get_params(self, deep=True):
    out = dict()
    for key in ['lr', 'latent_dim', 'num_well', 'batch_size', 'epoch', 'factor']:  # 这里是所用超参数的list
        value = getattr(self, key, None)
        if deep and hasattr(value, 'get_params'):
            deep_items = value.get_params().items()
            out.update((key + '__' + k, val) for k, val in deep_items)
        out[key] = value
    # print(out)
    return out

def set_params(self, **params):
    if not params:
        # Simple optimization to gain speed (inspect is slow)
        return self
    valid_params = self.get_params(deep=True)
    for key, value in params.items():
        if key not in valid_params:
            raise ValueError('Invalid parameter %s for estimator %s. '
                             'Check the list of available parameters '
                             'with `estimator.get_params().keys()`.' %
                             (key, self))
        setattr(self, key, value)
        valid_params[key] = value
    # print("=============param==========")
    # print(valid_params)
    return self
    
    def fit(self, x, y, sample_weight=None, **kwargs):
        K.clear_session()
        self.model = self.build_fn(**self.filter_sk_params(self.build_fn))
        fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit))
        fit_args.update(kwargs)
        history = self.model.fit(x, y, **fit_args)
        return history

    def predict(self, x, **kwargs):
        kwargs = self.filter_sk_params(Sequential.predict, kwargs)
        return np.squeeze(self.model.predict(x, **kwargs))

    def score(self, x, y, **kwargs):
        
        loss = self.model.evaluate(x, y, **kwargs)
        for name, output in zip(self.model.metrics_names, loss):
            # print(name, output)
            if name == "all_loss_total_4":
                print(output)
                return -output
        return -loss[-2]

 

然后就可以跟一般的用法一样

grid = RandomizedSearchCV(estimator=MyModel(), param_distributions=parm_grid, n_jobs=1, cv=5, refit=False)
results = grid.fit(x,y)

 

注意:

该类中一定要有上述几个方法,fit进行训练,score对每个训练进行评价。在使用searchCv进行搜索时,cv表示采用的交叉验证的折数,refit表示最终是否需要采用选择出的最有参数,一整个输入的为数据集对整个模型进行训练。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值