How to use GridSearchCV() to select hyperparameters for a custom model

The problem

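After designing my own model, I wanted to choose its hyperparameters by cross-validation (CV). Writing the CV code myself felt like reinventing the wheel, so I turned to sklearn.model_selection.GridSearchCV(). Plugging the model in directly did not work, mainly because it lacked some required methods: get_params, set_params, and score (the latter can be replaced by the scoring argument of GridSearchCV()). I have summarized the required method structure below. Once these methods are filled in as described, the model can be used with GridSearchCV().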
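To make it clearer why these methods are needed, here is a rough pure-Python sketch of the kind of loop a grid search runs. It is not scikit-learn's actual implementation; `TinyModel` and `tiny_grid_search` are hypothetical names made up for this illustration, and the fold splitting is deliberately simplistic.

```python
from itertools import product


class TinyModel:
    """Hypothetical estimator: predicts shrink * mean(y_train) for every row."""

    def __init__(self, shrink=1.0):
        self.shrink = shrink

    def get_params(self, deep=True):
        return {"shrink": self.shrink}

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.shrink * self.mean_ for _ in X]

    def score(self, X, y):
        # Negative mean squared error, so that larger means better.
        pred = self.predict(X)
        return -sum((p - t) ** 2 for p, t in zip(pred, y)) / len(y)


def tiny_grid_search(model, param_grid, X, y, n_splits=2):
    """Try every parameter combination and keep the best mean fold score."""
    names = sorted(param_grid)
    best_params, best_score = None, float("-inf")
    for values in product(*(param_grid[n] for n in names)):
        params = dict(zip(names, values))
        fold_scores = []
        for k in range(n_splits):
            test_idx = [i for i in range(len(y)) if i % n_splits == k]
            train_idx = [i for i in range(len(y)) if i % n_splits != k]
            # A fresh copy per fit: rebuild from get_params(), then set_params().
            candidate = type(model)(**model.get_params()).set_params(**params)
            candidate.fit([X[i] for i in train_idx], [y[i] for i in train_idx])
            fold_scores.append(
                candidate.score([X[i] for i in test_idx], [y[i] for i in test_idx])
            )
        mean_score = sum(fold_scores) / len(fold_scores)
        if mean_score > best_score:
            best_params, best_score = params, mean_score
    return best_params, best_score


X = [[0], [1], [2], [3]]
y = [2.0, 2.0, 2.0, 2.0]
best_params, best_score = tiny_grid_search(
    TinyModel(), {"shrink": [0.5, 1.0, 1.5]}, X, y
)
print(best_params, best_score)
```

The loop touches the model only through get_params, set_params, fit, predict and score, which is why a model missing any of them cannot be searched this way.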

Solution

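class mymodel():
    def __init__(self, h=None, lam=1, maxiter=500, tol=1e-6):
        # Hyperparameters: store each one unchanged under its own name,
        # because get_params()/set_params() look them up by these names.
        self.h = h
        self.lam = lam
        self.maxiter = maxiter
        self.tol = tol
        # State that is filled in by fit()
        self.beta = None
        self.dataset = None
        self.funvalue = None
        self.coef = None

    def fit(self, X_train, y_train):
        # Train the model parameters, e.g. self.coef.
        # myfun stands for your own training routine.
        self.coef, self.funvalue = myfun(X_train, y_train)
        return self

    def predict(self, X_new):
        # Predict y from X_new and return the array of predicted values.
        # XXXXXXX  (compute y_pre here)
        return y_pre
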

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : boolean, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        out = dict()
        for key in ['h', 'lam', 'maxiter', 'tol']:  # list of the hyperparameter names in use
            value = getattr(self, key, None)
            if deep and hasattr(value, 'get_params'):
                deep_items = value.get_params().items()
                out.update((key + '__' + k, val) for k, val in deep_items)
            out[key] = value
        return out

    def set_params(self, **params):
        """Set the parameters of this estimator.

        The method works on simple estimators as well as on nested objects
        (such as pipelines). The latter have parameters of the form
        ``<component>__<parameter>`` so that it's possible to update each
        component of a nested object.

        Returns
        -------
        self
        """
        if not params:
            # Simple optimization to gain speed (inspect is slow)
            return self
        valid_params = self.get_params(deep=True)


        for key, value in params.items():
            if key not in valid_params:
                raise ValueError('Invalid parameter %s for estimator %s. '
                                 'Check the list of available parameters '
                                 'with `estimator.get_params().keys()`.' %
                                 (key, self))
            setattr(self, key, value)
            valid_params[key] = value

        return self

    def score(self, X, y, sample_weight=None):
        """Return the score of self.predict(X) with respect to y.

        GridSearchCV keeps the parameters with the LARGEST score, so a loss
        (smaller is better) is negated before it is returned.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Test samples.

        y : array-like, shape = (n_samples) or (n_samples, n_outputs)
            True values for X.

        sample_weight : array-like, shape = [n_samples], optional
            Sample weights.

        Returns
        -------
        score : float
            Negative loss of self.predict(X) wrt. y.

        """
        # If score() is not defined here, pass the scoring argument to
        # GridSearchCV() instead.
        return -myloss_fun(y, self.predict(X), sample_weight=sample_weight)
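
As an alternative to writing get_params and set_params by hand, the model can inherit from sklearn.base.BaseEstimator, which derives both from the signature of __init__. The following is a minimal sketch of that approach, not the model from this post: `MyRidge` is a hypothetical stand-in (closed-form ridge regression) and the data is synthetic.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.model_selection import GridSearchCV


class MyRidge(RegressorMixin, BaseEstimator):
    """Hypothetical stand-in for a custom model: closed-form ridge regression."""

    def __init__(self, lam=1.0):
        # Only store the hyperparameter; BaseEstimator reads the signature
        # of __init__ to implement get_params() and set_params().
        self.lam = lam

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)
        n_features = X.shape[1]
        # Solve (X^T X + lam * I) coef = X^T y
        self.coef_ = np.linalg.solve(
            X.T @ X + self.lam * np.eye(n_features), X.T @ y
        )
        return self

    def predict(self, X):
        return np.asarray(X, dtype=float) @ self.coef_


# Synthetic regression data (an assumption made for this example)
rng = np.random.default_rng(0)
X = rng.normal(size=(60, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=60)

gsearch = GridSearchCV(
    MyRidge(),
    param_grid={"lam": [0.01, 1.0, 100.0]},
    scoring="neg_mean_squared_error",  # larger (closer to 0) is better
    cv=5,
)
gsearch.fit(X, y)
print(gsearch.best_params_)
```

Here the scoring argument is passed explicitly, so the class does not need its own score method; RegressorMixin would otherwise supply an R^2-based one.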

Using GridSearchCV()

Parameter settings

https://blog.csdn.net/weixin_34342578/article/details/92665252?depth_1-
https://blog.csdn.net/qq_41076797/article/details/102755893
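
For a quick orientation before following the links, here is a small sketch of the arguments that come up most often. The Ridge estimator, the synthetic data, and the particular grid values are assumptions chosen only to keep the example runnable; they are not taken from the linked posts.

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, KFold

X, y = make_regression(n_samples=80, n_features=5, noise=5.0, random_state=0)

gsearch = GridSearchCV(
    estimator=Ridge(),
    # Every combination is tried: 3 alphas x 2 intercept settings = 6 candidates
    param_grid={"alpha": [0.1, 1.0, 10.0], "fit_intercept": [True, False]},
    # Metric used to rank candidates; larger is always treated as better
    scoring="neg_mean_squared_error",
    # Cross-validation strategy: an int or a splitter object
    cv=KFold(n_splits=4, shuffle=True, random_state=0),
    # Refit the best candidate on the whole dataset after the search
    refit=True,
    # Also record training scores in cv_results_
    return_train_score=True,
    # Number of parallel jobs
    n_jobs=1,
)
gsearch.fit(X, y)
print(gsearch.best_params_)
```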

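Results

GridSearchCV result attributes explained

**Note:** earlier versions exposed the result as grid_scores_. Starting with version 0.18 it was deprecated in favor of cv_results_, and it was removed in 0.20.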

gsearch.best_estimator_
gsearch.best_index_
gsearch.best_params_
gsearch.best_score_
gsearch.classes_
gsearch.cv
gsearch.cv_results_
gsearch.error_score
gsearch.estimator
gsearch.fit
gsearch.fit_params
gsearch.get_params
gsearch.grid_scores_
gsearch.iid
gsearch.multimetric_
gsearch.n_jobs
gsearch.n_splits_
gsearch.param_grid
gsearch.pre_dispatch
gsearch.predict
gsearch.predict_log_proba
gsearch.predict_proba
gsearch.refit
gsearch.return_train_score
gsearch.score
gsearch.scorer_
gsearch.scoring
gsearch.set_params
gsearch.verbose
## The following three are not available with a random forest classifier
### gsearch.decision_function
### gsearch.inverse_transform
### gsearch.transform
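
A short sketch of how the most useful of these attributes might be read after fitting. The random forest, the synthetic data, and the tiny grid are assumptions for illustration only.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

X, y = make_classification(n_samples=100, n_features=6, random_state=0)

gsearch = GridSearchCV(
    RandomForestClassifier(n_estimators=10, random_state=0),
    param_grid={"max_depth": [2, 4]},
    cv=3,
)
gsearch.fit(X, y)

# Best candidate and its mean cross-validated score
print(gsearch.best_params_)
print(gsearch.best_score_)

# cv_results_ is a dict of arrays with one entry per candidate
results = gsearch.cv_results_
for params, mean, std in zip(
    results["params"], results["mean_test_score"], results["std_test_score"]
):
    print(params, round(mean, 3), round(std, 3))

# Methods are only exposed if the underlying estimator has them
print(hasattr(gsearch, "predict_proba"))
print(hasattr(gsearch, "decision_function"))
```

The last two lines check the point made in the list above: the search object forwards predict_proba to the random forest but has no decision_function, because the forest does not define one.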

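Using RandomizedSearchCV()

RandomizedSearchCV() takes a distribution for each parameter and trains on a number of parameter settings sampled at random from those distributions, which suits cases where the dataset is large. After fitting, cv_results_ holds the information for the model obtained from each randomly sampled setting.
For reference, see: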
https://blog.csdn.net/u014248127/article/details/78938561
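
A minimal sketch of a randomized search, assuming a Ridge estimator and synthetic data purely for illustration. The log-uniform range for alpha is likewise an arbitrary choice.

```python
from scipy.stats import loguniform
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import RandomizedSearchCV

X, y = make_regression(n_samples=80, n_features=5, noise=5.0, random_state=0)

rsearch = RandomizedSearchCV(
    Ridge(),
    # A distribution instead of a fixed list: alpha is sampled from it
    param_distributions={"alpha": loguniform(1e-3, 1e2)},
    n_iter=8,         # number of sampled parameter settings
    cv=4,
    random_state=0,   # makes the sampled settings reproducible
)
rsearch.fit(X, y)

print(rsearch.best_params_)
print(len(rsearch.cv_results_["params"]))
```

The cost is controlled by n_iter rather than by the size of a grid, so the search budget stays fixed however finely the parameter range is explored.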
