[Python Advanced] Building a fit and transform pipeline modeled on scikit-learn

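scikit-learn gives its preprocessors, classifiers and other estimators a uniform interface (fit, plus transform and fit_transform for transformers), which makes them easy to understand and to call. This post borrows that design idea and builds a small fit-and-transform pipeline.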
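The value of a uniform interface is that callers only rely on method names, not on concrete classes. As an illustration (an assumption, not part of the original design, which uses an abstract base class further down), the sketch below describes that contract structurally with `typing.Protocol` (Python 3.8+). `FitTransformer` and `Scale` are hypothetical names; `Scale` divides each column by its maximum absolute value and satisfies the protocol without inheriting from it.

```python
import typing

import numpy as np


@typing.runtime_checkable
class FitTransformer(typing.Protocol):
    """Hypothetical protocol: any object with the three sklearn-style methods."""

    def fit(self, data): ...

    def transform(self, data): ...

    def fit_transform(self, data): ...


class Scale:
    """Hypothetical example: divide each column by its maximum absolute value.

    It does not inherit from FitTransformer; it only has matching methods.
    """

    def __init__(self):
        self.scale = None

    def fit(self, data: np.ndarray):
        # Learn the per-column scale (an all-zero column would divide by zero)
        self.scale = np.max(np.abs(data), axis=0)
        return self  # returning self is what allows fit(...).transform(...)

    def transform(self, data: np.ndarray):
        return data / self.scale

    def fit_transform(self, data: np.ndarray):
        return self.fit(data).transform(data)


x = np.array([[1.0, -4.0], [2.0, 2.0]])
s = Scale()
print(isinstance(s, FitTransformer))  # True
print(s.fit_transform(x).tolist())    # [[0.5, -1.0], [1.0, 0.5]]
```

Keep in mind that `isinstance` against a runtime-checkable protocol only checks that the methods exist, not their signatures; the ABC approach used in this post additionally refuses to instantiate classes that forget to implement `fit` or `transform`.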

import typing
import abc
import functools
import numpy as np


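def transform_validate(func: typing.Callable):
    """Make sure fit() has been called before the wrapped transform() runs."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # fit() is expected to store its results in self._context;
        # an empty context means the processor has not been fitted yet.
        if not self._context:
            raise RuntimeError(f"{self.__class__.__name__} has not been fitted yet; call fit() first!")
        return func(self, *args, **kwargs)
    return wrapper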


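def chain_fit_transform(processors: typing.List['Processor']):
    """Chain several processors together."""

    # Copy the name and docstring of chain_fit_transform onto the closure,
    # then extend the name with the processors it contains.
    @functools.wraps(chain_fit_transform)
    def wrapper(data):
        # Feed the output of each processor into the next one
        for p in processors:
            data = p.fit_transform(data)
        return data

    processors_names = "->".join(p.__class__.__name__ for p in processors)
    wrapper.__name__ += ' including ' + processors_names
    return wrapper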


class Processor(abc.ABC):
    def __init__(self):
        self._context = {}    # state learned in fit() and read by transform()

    @abc.abstractmethod
    def fit(self, *args, **kwargs):
        """Learn state from the data; must return self."""

    @abc.abstractmethod
    def transform(self, *args, **kwargs):
        """Transform the data; must be called after fit()."""

    def fit_transform(self, *args, **kwargs):
        """Fit on the data, then transform the same data."""
        return self.fit(*args, **kwargs).transform(*args, **kwargs)


class MeanProcessor(Processor):
    """取平均值"""
    def fit(self, data: np.ndarray):
        assert data.ndim == 2
        self._context['mean'] = np.mean(data, axis=0)
        return self

    @transform_validate
    def transform(self, data: np.ndarray):
        assert data.ndim == 2
        return data - self._context['mean']


class OffsetProcessor(Processor):
    """
    整体偏移一个值
    """
    def __init__(self, value=0):
        super(OffsetProcessor, self).__init__()
        self._context['offset'] = value

    def fit(self, data: np.ndarray):
        return self

    def transform(self, data: np.ndarray):
        return data - self._context['offset']


if __name__ == '__main__':
    a = np.arange(12).reshape(3, 4)
    meanpro = MeanProcessor()
    offpro = OffsetProcessor(10)
    chained_processors = chain_fit_transform([meanpro, offpro])
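    result = chained_processors(a)
    # The column means are [4, 5, 6, 7]; after subtracting them and the offset 10,
    # every value in the three rows becomes -14, -10 and -6 respectively.
    print(result)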
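A quick way to see the two safeguards of this design in action is to trigger them on purpose. The self-contained sketch below repeats a trimmed copy of the classes above; it assumes the guard raises `RuntimeError` (scikit-learn itself raises `sklearn.exceptions.NotFittedError` in this situation), and `raises` is a hypothetical helper added only for the demonstration.

```python
import abc
import functools

import numpy as np


def transform_validate(func):
    """Refuse to run transform() until fit() has stored something in _context."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if not self._context:
            raise RuntimeError(f"{self.__class__.__name__} has not been fitted yet; call fit() first!")
        return func(self, *args, **kwargs)
    return wrapper


class Processor(abc.ABC):
    def __init__(self):
        self._context = {}

    @abc.abstractmethod
    def fit(self, data): ...

    @abc.abstractmethod
    def transform(self, data): ...

    def fit_transform(self, data):
        return self.fit(data).transform(data)


class MeanProcessor(Processor):
    def fit(self, data: np.ndarray):
        self._context['mean'] = np.mean(data, axis=0)
        return self

    @transform_validate
    def transform(self, data: np.ndarray):
        return data - self._context['mean']


def raises(exc_type, func, *args):
    """Hypothetical helper: True if calling func(*args) raises exc_type."""
    try:
        func(*args)
    except exc_type:
        return True
    return False


x = np.arange(12).reshape(3, 4)
print(raises(TypeError, Processor))                        # True: abstract class
print(raises(RuntimeError, MeanProcessor().transform, x))  # True: not fitted yet
print(MeanProcessor().fit(x).transform(x)[0].tolist())     # [-4.0, -4.0, -4.0, -4.0]
```

The first check comes from `abc`: a class with unimplemented abstract methods cannot be instantiated. The second comes from `transform_validate`, which only works because `MeanProcessor.fit` writes to `_context`; a processor whose `fit` stores nothing would be rejected forever by this guard.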

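The main pieces are:
(1) the transform_validate decorator, which ensures that a processor decorated with it has been fitted before transform is called;
(2) chain_fit_transform, which uses a closure to chain several processors and calls the fit_transform method of each sub-processor in turn;
(3) Processor, the abstract base class of all processors. It declares the abstract methods fit and transform: the former returns the processor instance itself, the latter transforms the data, and values are passed from fit to transform through the instance's _context dictionary.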
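Because chain_fit_transform calls fit_transform on every invocation, each call re-fits every processor on whatever data it receives. In scikit-learn's usual workflow you fit on training data once and then only transform new data with the learned state. A minimal sketch of that split, assuming a hypothetical `Pipeline` class (not part of the original code and far simpler than `sklearn.pipeline.Pipeline`), keeps the fitted processors around and separates the two phases:

```python
import abc
import typing

import numpy as np


class Processor(abc.ABC):
    """Same contract as above: fit() returns self, transform() reads _context."""

    def __init__(self):
        self._context = {}

    @abc.abstractmethod
    def fit(self, data): ...

    @abc.abstractmethod
    def transform(self, data): ...

    def fit_transform(self, data):
        return self.fit(data).transform(data)


class MeanProcessor(Processor):
    """Subtract the per-column mean learned during fit."""

    def fit(self, data: np.ndarray):
        self._context['mean'] = np.mean(data, axis=0)
        return self

    def transform(self, data: np.ndarray):
        return data - self._context['mean']


class OffsetProcessor(Processor):
    """Subtract a fixed offset; nothing to learn."""

    def __init__(self, value=0):
        super().__init__()
        self._context['offset'] = value

    def fit(self, data: np.ndarray):
        return self

    def transform(self, data: np.ndarray):
        return data - self._context['offset']


class Pipeline(Processor):
    """Hypothetical: chain processors, fit them once, then reuse them on new data."""

    def __init__(self, processors: typing.List[Processor]):
        super().__init__()
        self._context['steps'] = list(processors)

    def fit(self, data):
        # Each step is fitted on the output of the previous step,
        # just as chain_fit_transform does.
        for p in self._context['steps']:
            data = p.fit_transform(data)
        return self

    def transform(self, data):
        # Apply the already-fitted state only; no re-fitting here.
        for p in self._context['steps']:
            data = p.transform(data)
        return data


train = np.arange(12).reshape(3, 4)
test = np.full((1, 4), 100.0)
pipe = Pipeline([MeanProcessor(), OffsetProcessor(10)]).fit(train)
print(pipe.transform(test).tolist())  # [[86.0, 85.0, 84.0, 83.0]]
```

The test row is centered with the training means [4, 5, 6, 7] rather than its own mean, which is what storing state in `_context` makes possible. Since `Pipeline` is itself a `Processor`, pipelines can also be nested inside other pipelines.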
