GridSearchCV 是一个智能的调参工具,借助它可以让模型在给定的一组参数中自动选择最优的参数。Keras 自带的 scikit-learn 包装器(KerasClassifier 和 KerasRegressor)只支持分类和回归两类问题的调参;对于一些复杂的模型或者自己编写的模型,如果想使用这个方法,往往需要继承包装器的基类并重写其中的相关函数,如下所示:
class KerasRegressorMy(BaseWrapper):
    """scikit-learn compatible wrapper around a custom Keras model.

    Hyperparameters are stored as plain instance attributes (one per
    constructor argument) instead of in ``sk_params``, so that scikit-learn
    search utilities (GridSearchCV / RandomizedSearchCV) can read and set
    them through ``get_params`` / ``set_params``.

    NOTE(review): ``build_fn`` is not assigned anywhere in this class; it is
    presumably supplied by a subclass (e.g. as a method returning a compiled
    model) -- confirm against the caller.
    """

    # Names of the tunable hyperparameters; each one is both a constructor
    # argument and an instance attribute of the same name.
    _param_names = ("lr", "latent_dim", "num_well", "batch_size", "epoch", "factor")

    # Name (in ``model.metrics_names``) of the value that ``score`` reports,
    # negated. Subclasses may override this.
    score_metric_name = "all_loss_total_4"

    def __init__(self, lr=0.0001, latent_dim=128, num_well=189, batch_size=2,
                 epoch=1, factor=0.6):
        self.lr = lr
        self.latent_dim = latent_dim
        self.num_well = num_well
        self.batch_size = batch_size
        self.epoch = epoch
        self.factor = factor

    def get_params(self, deep=True):
        """Return the hyperparameters as a ``{name: value}`` dict.

        With ``deep=True``, a hyperparameter value that itself has
        ``get_params`` also contributes its nested parameters under
        ``"<name>__<sub_name>"`` keys.
        """
        out = {}
        for key in self._param_names:
            value = getattr(self, key, None)
            if deep and hasattr(value, "get_params"):
                out.update(
                    (key + "__" + sub_key, sub_value)
                    for sub_key, sub_value in value.get_params().items()
                )
            out[key] = value
        return out

    def set_params(self, **params):
        """Set hyperparameters by name and return ``self``.

        Raises:
            ValueError: if a name is not reported by ``get_params(deep=True)``.
        """
        if not params:
            # Simple optimization to gain speed (inspect is slow)
            return self
        valid_params = self.get_params(deep=True)
        for key, value in params.items():
            if key not in valid_params:
                raise ValueError('Invalid parameter %s for estimator %s. '
                                 'Check the list of available parameters '
                                 'with `estimator.get_params().keys()`.' %
                                 (key, self))
            setattr(self, key, value)
            valid_params[key] = value
        return self

    def filter_sk_params(self, fn, override=None):
        """Return the hyperparameters that ``fn`` accepts as keyword arguments.

        The inherited implementation reads ``self.sk_params``, which this
        class's ``__init__`` never sets. Here the candidates are any existing
        ``sk_params`` entries, overlaid with the attribute-based
        hyperparameters (so values set by a search through ``set_params``
        take effect), and finally with ``override``.
        """
        import inspect

        candidates = dict(getattr(self, "sk_params", None) or {})
        candidates.update(self.get_params(deep=False))

        accepted = inspect.signature(fn).parameters
        keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                         inspect.Parameter.KEYWORD_ONLY)
        res = {
            name: value
            for name, value in candidates.items()
            if name in accepted and accepted[name].kind in keyword_kinds
        }
        res.update(override or {})
        return res

    def fit(self, x, y, sample_weight=None, **kwargs):
        """Build a fresh model with ``build_fn`` and train it on ``(x, y)``.

        Returns whatever ``self.model.fit`` returns.
        """
        # Reset the Keras backend before each build, presumably so repeated
        # cross-validation fits do not accumulate graphs/state.
        K.clear_session()
        self.model = self.build_fn(**self.filter_sk_params(self.build_fn))

        fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit))
        # Keras' fit() names this argument ``epochs``; without this mapping
        # the ``epoch`` hyperparameter would be silently ignored.
        if self.epoch is not None:
            fit_args.setdefault("epochs", self.epoch)
        if sample_weight is not None:
            fit_args["sample_weight"] = sample_weight
        fit_args.update(kwargs)
        return self.model.fit(x, y, **fit_args)

    def predict(self, x, **kwargs):
        """Predict with the trained model, squeezing singleton dimensions."""
        kwargs = self.filter_sk_params(Sequential.predict, kwargs)
        return np.squeeze(self.model.predict(x, **kwargs))

    def score(self, x, y, **kwargs):
        """Return the negated evaluation loss (search tools maximise scores).

        If ``evaluate`` reports several values, the one named
        ``score_metric_name`` is used. Otherwise the second-to-last value is
        used, or the only value when just one is reported.
        """
        kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)
        loss = self.model.evaluate(x, y, **kwargs)
        if not isinstance(loss, (list, tuple)):
            # A model with a single output and no metrics yields a scalar.
            return -loss
        for name, value in zip(self.model.metrics_names, loss):
            if name == self.score_metric_name:
                return -value
        return -loss[-2] if len(loss) >= 2 else -loss[0]
然后就可以像一般的用法一样使用:
# Randomized search over the hyperparameters exposed by the wrapper's
# get_params(); swap in GridSearchCV(param_grid=...) for an exhaustive search.
# NOTE(review): MyModel is presumably a subclass of KerasRegressorMy that
# provides build_fn -- confirm.
grid = RandomizedSearchCV(estimator=MyModel(), param_distributions=parm_grid,
                          n_jobs=1, cv=5, refit=False)
results = grid.fit(x, y)
注意:
该类中一定要有上述几个方法:get_params 和 set_params 供搜索工具读取和设置超参数,fit 进行训练,score 对每次训练的结果进行评价(搜索时会选择分数最高的参数,因此 score 返回的是损失的相反数)。在使用 SearchCV 进行搜索时,cv 表示交叉验证的折数;refit 表示搜索结束后是否用选出的最优参数在整个输入数据集上重新训练模型。