Study notes on sklearn estimators: BaseEstimator
I. BaseEstimator
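As I understand it, BaseEstimator (that is, the "estimator") is one of the base classes for sklearn's machine learning models such as classifiers and regressors. Judging from the source, it defines a set of operations for parameters, object state, data validation, tags, and the HTML display of an estimator.
Other base classes for classifiers and regressors include ClassifierMixin and RegressorMixin (which mainly define the `score` method).
All of these base classes live in sklearn's base.py; on a Windows install the relative path is \Lib\site-packages\sklearn\base.py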
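To make the mixin idea concrete, here is a minimal sketch, not sklearn's code: `ToyRegressorMixin` and `DoubleIt` are hypothetical names. The mixin assumes only that the host class provides `predict`, which mirrors how sklearn's `RegressorMixin.score` returns the coefficient of determination R^2 of `self.predict(X)` (the real method delegates to `sklearn.metrics.r2_score` and also accepts `sample_weight`).

```python
# Hypothetical sketch, not sklearn's implementation: a mixin that adds an
# R^2 score() to any class that already implements predict().


class ToyRegressorMixin:
    """Adds score() to any class that implements predict()."""

    def score(self, X, y):
        y_pred = self.predict(X)
        mean_y = sum(y) / len(y)
        ss_res = sum((yt - yp) ** 2 for yt, yp in zip(y, y_pred))  # residual SS
        ss_tot = sum((yt - mean_y) ** 2 for yt in y)  # total SS around the mean
        # R^2 = 1 - SS_res / SS_tot
        return 1.0 - ss_res / ss_tot


class DoubleIt(ToyRegressorMixin):
    """A toy 'regressor' that always predicts 2 * x."""

    def predict(self, X):
        return [2 * x for x in X]


model = DoubleIt()
print(model.score([1, 2, 3], [2, 4, 6]))  # 1.0 (perfect fit)
print(model.score([1, 2, 3], [2, 4, 7]))  # below 1.0: one residual is off by 1
```

The mixin holds no state and has no `__init__`, so it can be combined with any estimator class through multiple inheritance.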
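The BaseEstimator source is shown below with the method bodies omitted: the point is to see which APIs the class defines, what each one is for, and what the class as a whole is for. (The excerpt appears to come from an older release, around 0.24; later versions have reworked several of these private helpers.)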
```python
# Method bodies are omitted; `...` marks each elided implementation.
class BaseEstimator:
    """Base class for all estimators in scikit-learn.

    Notes
    -----
    All estimators should specify all the parameters that can be set
    at the class level in their ``__init__`` as explicit keyword
    arguments (no ``*args`` or ``**kwargs``).
    """

    @classmethod
    def _get_param_names(cls):
        """Get parameter names for the estimator"""
        # fetch the constructor or the original constructor before
        # deprecation wrapping if any
        ...

    def get_params(self, deep=True):
        """
        Get parameters for this estimator.
        """
        ...

    def set_params(self, **params):
        """
        Set the parameters of this estimator.

        The method works on simple estimators as well as on nested objects
        (such as :class:`~sklearn.pipeline.Pipeline`). The latter have
        parameters of the form ``<component>__<parameter>`` so that it's
        possible to update each component of a nested object.
        """
        ...

    def __repr__(self, N_CHAR_MAX=700):
        # N_CHAR_MAX is the (approximate) maximum number of non-blank
        # characters to render. We pass it as an optional parameter to ease
        # the tests.
        ...

    def __getstate__(self):
        ...

    def _more_tags(self):
        ...

    def _get_tags(self):
        ...

    def _check_n_features(self, X, reset):
        """Set the `n_features_in_` attribute, or check against it.
        """
        ...

    def _validate_data(self, X, y='no_validation', reset=True,
                       validate_separately=False, **check_params):
        """Validate input data and set or check the `n_features_in_` attribute.
        """
        ...

    @property
    def _repr_html_(self):
        """HTML representation of estimator.

        This is redundant with the logic of `_repr_mimebundle_`. The latter
        should be favored in the long term, `_repr_html_` is only
        implemented for consumers who do not interpret `_repr_mimebundle_`.
        """
        ...

    def _repr_html_inner(self):
        """This function is returned by the @property `_repr_html_` to make
        `hasattr(estimator, "_repr_html_")` return `True` or `False` depending
        on `get_config()["display"]`.
        """
        ...

    def _repr_mimebundle_(self, **kwargs):
        ...
```
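The key idea behind `_get_param_names`, `get_params` and `set_params` is that parameter names are read from the `__init__` signature, which is why `*args`/`**kwargs` are ruled out. The sketch below re-implements that idea in simplified form; `MiniBaseEstimator`, `Scaler` and `Wrapper` are hypothetical classes, and the real methods do more (for example, the real `set_params` rejects unknown parameter names).

```python
import inspect


class MiniBaseEstimator:
    """Hypothetical, simplified stand-in for BaseEstimator's parameter API."""

    @classmethod
    def _get_param_names(cls):
        # Read parameter names from the __init__ signature; *args / **kwargs
        # cannot be enumerated this way, so they are rejected.
        sig = inspect.signature(cls.__init__)
        names = []
        for name, p in sig.parameters.items():
            if name == "self":
                continue
            if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD):
                raise RuntimeError(f"{cls.__name__} must not use *args/**kwargs")
            names.append(name)
        return sorted(names)

    def get_params(self, deep=True):
        out = {}
        for name in self._get_param_names():
            value = getattr(self, name)
            out[name] = value
            # With deep=True, expose nested parameters as "<component>__<param>"
            if deep and hasattr(value, "get_params"):
                for sub_name, sub_value in value.get_params(deep=True).items():
                    out[f"{name}__{sub_name}"] = sub_value
        return out

    def set_params(self, **params):
        for key, value in params.items():
            name, _, sub_key = key.partition("__")
            if sub_key:
                # Route "<component>__<param>" to the nested estimator
                getattr(self, name).set_params(**{sub_key: value})
            else:
                setattr(self, name, value)
        return self


class Scaler(MiniBaseEstimator):
    def __init__(self, with_mean=True):
        self.with_mean = with_mean  # store the argument unchanged, same name


class Wrapper(MiniBaseEstimator):
    def __init__(self, step=None, verbose=False):
        self.step = step
        self.verbose = verbose


w = Wrapper(step=Scaler())
w.set_params(step__with_mean=False, verbose=True)
print(w.get_params())  # includes the nested key "step__with_mean"
```

Because `get_params` simply reads attributes named after the `__init__` arguments, an estimator's `__init__` should store each argument unchanged under the same name; that convention is also what lets `sklearn.base.clone` rebuild an estimator from `get_params(deep=False)`.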
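II. How BaseEstimator is used in sklearn, and the design structure
The inheritance relationships behind sklearn's classifiers, regressors, and other models form a forest rather than a single tree: BaseEstimator is the root node (base class) of one tree, while mixins such as RegressorMixin are the roots of others.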
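The forest shape can be reproduced with plain Python classes. In this sketch every class is a hypothetical stand-in named after a real one; Python's method resolution order (MRO) shows how a single concrete model pulls in several independent roots at once.

```python
# Hypothetical stand-ins named after the real sklearn classes.
class BaseEstimatorToy:  # root of the estimator tree
    pass


class RegressorMixinToy:  # independent root: does not inherit BaseEstimatorToy
    pass


class MultiOutputMixinToy:  # another independent root
    pass


class LinearModelToy(BaseEstimatorToy):  # inner node of the estimator tree
    pass


# A concrete model joins three trees through multiple inheritance
class LinearRegressionToy(MultiOutputMixinToy, RegressorMixinToy, LinearModelToy):
    pass


print([cls.__name__ for cls in LinearRegressionToy.__mro__])
# ['LinearRegressionToy', 'MultiOutputMixinToy', 'RegressorMixinToy',
#  'LinearModelToy', 'BaseEstimatorToy', 'object']
print(issubclass(RegressorMixinToy, BaseEstimatorToy))  # False
```

Listing the mixins to the left of the estimator base class puts them earlier in the MRO, so their methods take precedence; sklearn's contributor documentation likewise recommends placing mixins to the left of `BaseEstimator`.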
1. Example: LinearModel (the base class for linear models, which builds on BaseEstimator to define operations shared by all linear models)
LinearModel source (again, implementations omitted to focus on the design):
```python
from abc import ABCMeta, abstractmethod
from sklearn.base import BaseEstimator


class LinearModel(BaseEstimator, metaclass=ABCMeta):
    """Base class for Linear Models"""

    @abstractmethod
    def fit(self, X, y):
        """Fit model."""

    def _decision_function(self, X):
        ...

    def predict(self, X):
        """
        Predict using the linear model.
        """
        ...

    def _set_intercept(self, X_offset, y_offset, X_scale):
        """Set the intercept_
        """
        ...

    def _more_tags(self):
        ...
```
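The division of labour in `LinearModel` (subclasses implement `fit`, the base class supplies `predict` through `_decision_function`) can be sketched with the standard `abc` module alone. `ToyLinearModel` and `FixedLine` are hypothetical; the real `_decision_function` works on NumPy arrays or sparse matrices rather than Python lists.

```python
from abc import ABCMeta, abstractmethod


class ToyLinearModel(metaclass=ABCMeta):
    """Hypothetical sketch of the LinearModel idea: subclasses fit, the base predicts."""

    @abstractmethod
    def fit(self, X, y):
        """Each concrete model must set coef_ and intercept_."""

    def _decision_function(self, X):
        # Shared by all linear models: row . coef_ + intercept_
        return [sum(w * x for w, x in zip(self.coef_, row)) + self.intercept_
                for row in X]

    def predict(self, X):
        return self._decision_function(X)


class FixedLine(ToyLinearModel):
    """Toy subclass that 'fits' by hard-coding y = 1*x0 + 2*x1 + 3."""

    def fit(self, X, y):
        self.coef_ = [1.0, 2.0]
        self.intercept_ = 3.0
        return self


try:
    ToyLinearModel()  # abstract: fit is not implemented
except TypeError as exc:
    print("TypeError:", exc)

print(FixedLine().fit(None, None).predict([[3, 5]]))  # [16.0]
```

Since `fit` is abstract, the base class itself cannot be instantiated; any subclass that implements `fit` inherits a working `predict` for free.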
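2. With LinearModel and the base classes from base.py such as RegressorMixin in place, the linear regressor LinearRegression can now be put together.
LinearRegression source (again, implementations omitted to focus on the design):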
```python
from sklearn.base import MultiOutputMixin, RegressorMixin

# LinearModel is the class shown above; _deprecate_positional_args is a
# private sklearn decorator found in older releases.


class LinearRegression(MultiOutputMixin, RegressorMixin, LinearModel):
    """
    Ordinary least squares Linear Regression.

    LinearRegression fits a linear model with coefficients w = (w1, ..., wp)
    to minimize the residual sum of squares between the observed targets in
    the dataset, and the targets predicted by the linear approximation.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.linear_model import LinearRegression
    >>> X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
    >>> # y = 1 * x_0 + 2 * x_1 + 3
    >>> y = np.dot(X, np.array([1, 2])) + 3
    >>> reg = LinearRegression().fit(X, y)
    >>> reg.score(X, y)
    1.0
    >>> reg.coef_
    array([1., 2.])
    >>> reg.intercept_
    3.0000...
    >>> reg.predict(np.array([[3, 5]]))
    array([16.])
    """

    @_deprecate_positional_args
    def __init__(self, *, fit_intercept=True, normalize=False, copy_X=True,
                 n_jobs=None, positive=False):
        ...

    def fit(self, X, y, sample_weight=None):
        """
        Fit linear model.
        """
        ...
```
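`fit` is where `LinearRegression` solves ordinary least squares; for dense input sklearn hands the problem to `scipy.linalg.lstsq`. The hypothetical helper `ols_fit` below instead solves the normal equations (A^T A) b = A^T y in pure Python, a simpler but less numerically stable method, only to reproduce the coefficients from the docstring example.

```python
def ols_fit(X, y):
    """Hypothetical helper: return (coef, intercept) minimizing the residual
    sum of squares, via the normal equations (A^T A) b = A^T y, where A is X
    with a leading column of ones for the intercept."""
    A = [[1.0] + [float(v) for v in row] for row in X]
    n = len(A[0])
    # Build the augmented system [A^T A | A^T y]
    M = [
        [sum(r[i] * r[j] for r in A) for j in range(n)]
        + [sum(r[i] * t for r, t in zip(A, y))]
        for i in range(n)
    ]
    # Gauss-Jordan elimination with partial pivoting
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(n):
            if r != col:
                factor = M[r][col] / M[col][col]
                M[r] = [a - factor * b for a, b in zip(M[r], M[col])]
    b = [M[i][n] / M[i][i] for i in range(n)]
    return b[1:], b[0]


X = [[1, 1], [1, 2], [2, 2], [2, 3]]
y = [1 * x0 + 2 * x1 + 3 for x0, x1 in X]  # same data as the docstring example
coef, intercept = ols_fit(X, y)
print([round(c, 6) for c in coef], round(intercept, 6))  # [1.0, 2.0] 3.0
```

The recovered values match `coef_` and `intercept_` in the docstring; a production implementation avoids forming A^T A explicitly because that squares the condition number.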
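Summary
BaseEstimator is one of the base classes of sklearn's classifiers, regressors, and other machine learning models (every such model can be regarded as an estimator). Judging from the source, it defines a set of operations for parameters, object state, data validation, tags, and the HTML display of an estimator.
The inheritance relationships behind sklearn's classifiers, regressors, and other models form a forest, and BaseEstimator is the root node (base class) of one of its trees.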