(估计器estimator) sklearn中BaseEstimator学习记录


一、BaseEstimator

我的理解,BaseEstimator(即estimator)是sklearn中分类器和回归器等机器算法模型的基类之一,从源码来看它定义了一些列关于参数、状态、数据检测、标签和estimator的html展示的操作集。

同样的分类器和回归器的基类还有ClassifierMixin、RegressorMixin(主要定义score函数)等。

这些基类都在sklearn的base.py里,相对路径是\Lib\site-packages\sklearn\base.py

BaseEstimator的源码如下:这里忽略了具体实现,主要看的是设计的API有哪些,用途以及该类的用途

class BaseEstimator:
    """Base class for all estimators in scikit-learn.

    Notes
    -----
    All estimators should specify all the parameters that can be set
    at the class level in their ``__init__`` as explicit keyword
    arguments (no ``*args`` or ``**kwargs``).
    """

    @classmethod
    def _get_param_names(cls):
        """Get parameter names for the estimator"""
        # fetch the constructor or the original constructor before
        # deprecation wrapping if any
 

    def get_params(self, deep=True):
        """
        Get parameters for this estimator.
        """
       

    def set_params(self, **params):
        """
        Set the parameters of this estimator.

        The method works on simple estimators as well as on nested objects
        (such as :class:`~sklearn.pipeline.Pipeline`). The latter have
        parameters of the form ``<component>__<parameter>`` so that it's
        possible to update each component of a nested object.
        """
       

    def __repr__(self, N_CHAR_MAX=700):
        # N_CHAR_MAX is the (approximate) maximum number of non-blank
        # characters to render. We pass it as an optional parameter to ease
        # the tests.

        

    def __getstate__(self):
        

    def _more_tags(self):
        

    def _get_tags(self):
        

    def _check_n_features(self, X, reset):
        """Set the `n_features_in_` attribute, or check against it.
        """
        
    def _validate_data(self, X, y='no_validation', reset=True,
                       validate_separately=False, **check_params):
        """Validate input data and set or check the `n_features_in_` attribute.
        """

       

    @property
    def _repr_html_(self):
        """HTML representation of estimator.

        This is redundant with the logic of `_repr_mimebundle_`. The latter
        should be favorted in the long term, `_repr_html_` is only
        implemented for consumers who do not interpret `_repr_mimbundle_`.
        """
       

    def _repr_html_inner(self):
        """This function is returned by the @property `_repr_html_` to make
        `hasattr(estimator, "_repr_html_") return `True` or `False` depending
        on `get_config()["display"]`.
        """
        

    def _repr_mimebundle_(self, **kwargs):

二、关于BaseEstimator在sklearn中的应用和设计结构

sklearn中分类器和回归器等机器算法模型的实现所依赖的类继承关系呈现出森林状,BaseEstimator是其中一棵树的根节点(基类)

1.例子LinearModel(线性模型的基类,基于BaseEstimator定义线性模型的一些操作)

代码如下(示例):

Linear Model源码:同样忽略实现,看设计

class LinearModel(BaseEstimator, metaclass=ABCMeta):
    """Base class for Linear Models"""

    @abstractmethod
    def fit(self, X, y):
        """Fit model."""

    def _decision_function(self, X):

    def predict(self, X):
        """
        Predict using the linear model.
        """

    def _set_intercept(self, X_offset, y_offset, X_scale):
        """Set the intercept_
        """

    def _more_tags(self):
       

2.有了LinearModel和之前的base.py定义的RegressorMixin等基类,我们的线性回归器LinearRegression就可以初步定义了

LinearRegression源码:同样忽略实现,看设计

class LinearRegression(MultiOutputMixin, RegressorMixin, LinearModel):
    """
    Ordinary least squares Linear Regression.

    LinearRegression fits a linear model with coefficients w = (w1, ..., wp)
    to minimize the residual sum of squares between the observed targets in
    the dataset, and the targets predicted by the linear approximation.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.linear_model import LinearRegression
    >>> X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
    >>> # y = 1 * x_0 + 2 * x_1 + 3
    >>> y = np.dot(X, np.array([1, 2])) + 3
    >>> reg = LinearRegression().fit(X, y)
    >>> reg.score(X, y)
    1.0
    >>> reg.coef_
    array([1., 2.])
    >>> reg.intercept_
    3.0000...
    >>> reg.predict(np.array([[3, 5]]))
    array([16.])
    """
    @_deprecate_positional_args
    def __init__(self, *, fit_intercept=True, normalize=False, copy_X=True,
                 n_jobs=None, positive=False):
      

    def fit(self, X, y, sample_weight=None):
        """
        Fit linear model.
        """

       

总结

BaseEstimator是sklearn中分类器和回归器等机器算法模型的基类之一(所有等机器算法模型可以视作一个estimator),从源码来看它定义了一些列关于参数、状态、数据检测、标签和estimator的html展示的操作集。

sklearn中分类器和回归器等机器算法模型的实现所依赖的类继承关系呈现出森林状,BaseEstimator是其中一棵树的根节点(基类)。

### 关于 `sklearn.utils` 中 Tags 的功能及使用方法 在 Scikit-Learn 库中,`sklearn.utils` 提供了许多实用工具来帮助开发者构建符合 Scikit-Learn API 风格的自定义估计器和转换器[^1]。其中,Tags 是一种元数据机制,用于描述特定估计器的行为特征或约束条件。 #### Tags 的主要作用 Tags 主要通过 `_get_tags()` 方法实现,该方法返回一个字典形式的数据结构,键为标签名称,值为布尔值或其他具体信息。这些标签通常由 Scikit-Learn 自带的估计器自动设置,也可以被用户自定义扩展。以下是 Tags 的一些常见用途: 1. **指示估计器支持的功能** Tags 可以用来标记某个估计器是否支持某些特性,例如多分类、回归或多输出预测等。这有助于框架内部逻辑判断何时调用特定方法。 2. **提供额外的信息** Tags 还能传递其他重要信息,比如估计器是否需要非负输入、是否允许稀疏矩阵作为输入等。这对于优化性能或验证输入有效性非常有帮助。 3. **辅助自动化测试与验证** 在开发过程中,Scikit-Learn 利用 Tags 来执行针对性更强的单元测试,从而确保不同类型的估计器满足其预期行为标准。 #### 如何查看现有估计器的 Tags? 可以通过访问估计器实例上的属性获取其默认 Tags 设置。例如: ```python from sklearn.linear_model import LogisticRegression log_reg = LogisticRegression() print(log_reg._get_tags()) ``` 上述代码会打印出 `LogisticRegression` 类所关联的一系列 Tags 和对应的值。 #### 用户如何自定义 Tags? 当创建新的估计器类时,继承自 `BaseEstimator` 并重写 `_more_tags` 方法即可添加个性化 Tags 定义。下面是一个简单的示例演示这一过程: ```python from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.utils.validation import check_is_fitted class MyCustomClassifier(BaseEstimator, ClassifierMixin): def __init__(self, param_a=1): self.param_a = param_a def fit(self, X, y=None): # 假设我们有一些训练逻辑... self.classes_ = list(set(y)) return self def predict(self, X): check_is_fitted(self) # 返回随机预测结果仅作示范之用 import random return [random.choice(self.classes_) for _ in range(len(X))] @classmethod def _more_tags(cls): return { 'non_deterministic': True, 'requires_positive_X': False, 'X_types': ['2darray'], '_xfail_checks': {'check_methods_subset_invariance': 'This estimator is not invariant to subset selection.'} } ``` 在这个例子中,`MyCustomClassifier` 添加了一个名为 `_more_tags` 的类方法,它返回一组定制化的 Tags 描述此分类器的独特性质。 --- ### 示例总结 - Tags 是 Scikit-Learn 中的一种元数据机制,旨在增强估计器之间的互操作性和灵活性。 - 开发者能够通过覆盖 `_more_tags` 方法来自定义所需 Tags。 - 查看内置估计器 Tags 的方式是调用对象的 `_get_tags()` 函数[^4]。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

我喝AD钙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值