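scikit-learn gives its preprocessors, classifiers, and estimators a uniform interface built around fit, transform, and fit_transform, which makes them easy to understand and call. This post imitates that design and builds a small fit/transform pipeline.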
import typing
import abc
import functools
import numpy as np


def transform_validate(func: typing.Callable):
    """Decorator: refuse to run transform while the context is still empty."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # An empty context means fit has not stored anything yet.
        if not self._context:
            raise NotImplementedError(f"{self.__class__.__name__} must be fitted before transform is called!")
        return func(self, *args, **kwargs)
    return wrapper


def chain_fit_transform(processors: typing.List['Processor']):
    """Chain several processors together."""
    @functools.wraps(chain_fit_transform)
    def wrapper(args):
        # Feed the output of each processor into the next one.
        for p in processors:
            args = p.fit_transform(args)
        return args
    processors_names = "->".join(p.__class__.__name__ for p in processors)
    wrapper.__name__ += ' including ' + processors_names
    return wrapper


class Processor(abc.ABC):
    def __init__(self):
        self._context = {}  # save context data for transform

    @abc.abstractmethod
    def fit(self, *args, **kwargs):
        """Fit on the data; expected to return self."""

    @abc.abstractmethod
    def transform(self, *args, **kwargs):
        """Transform the data; must be used after fit."""

    def fit_transform(self, *args, **kwargs):
        """Run fit and then transform on the same arguments."""
        return self.fit(*args, **kwargs).transform(*args, **kwargs)


class MeanProcessor(Processor):
    """Center the data by subtracting the column means."""

    def fit(self, data: np.ndarray):
        assert data.ndim == 2
        self._context['mean'] = np.mean(data, axis=0)
        return self

    @transform_validate
    def transform(self, data: np.ndarray):
        assert data.ndim == 2
        return data - self._context['mean']


class OffsetProcessor(Processor):
    """
    Shift all of the data by a constant value.
    """

    def __init__(self, value=0):
        super().__init__()
        self._context['offset'] = value

    def fit(self, data: np.ndarray):
        # Nothing to learn: the offset is supplied at construction time.
        return self

    def transform(self, data: np.ndarray):
        return data - self._context['offset']


if __name__ == '__main__':
    a = np.arange(12).reshape(3, 4)
    meanpro = MeanProcessor()
    offpro = OffsetProcessor(10)
    chained_processors = chain_fit_transform([meanpro, offpro])
    result = chained_processors(a)
    print(result)
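
The main pieces are:
(1) transform_validate: a decorator that ensures certain processors have been fitted before transform is called;
(2) chain_fit_transform: uses a closure to chain several processors and calls each sub-processor's fit_transform in turn, so each one receives the output of the previous one;
(3) Processor: the abstract base class of all processors. It declares the abstract methods fit and transform: fit returns the processor instance itself, transform applies the data transformation, and values are passed from fit to transform through the instance's _context;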
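
Points (1) and (3) can be seen in isolation with a minimal sketch. It is a simplified restatement of the pattern, not the listing itself: to stay dependency-free it works on plain Python lists instead of NumPy arrays, and ListMeanProcessor is a hypothetical class introduced only for this illustration. It shows that transform is rejected while _context is empty, and that a mean learned by fit on one dataset can then be applied to another.

```python
import abc
import functools


def transform_validate(func):
    """Reject transform calls made while the context is still empty."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if not self._context:
            raise NotImplementedError(
                f"{type(self).__name__} must be fitted before transform is called!")
        return func(self, *args, **kwargs)
    return wrapper


class Processor(abc.ABC):
    def __init__(self):
        self._context = {}  # values stored by fit, read back by transform

    @abc.abstractmethod
    def fit(self, data):
        """Learn from data and return self."""

    @abc.abstractmethod
    def transform(self, data):
        """Apply what fit stored."""


class ListMeanProcessor(Processor):
    """Hypothetical list-based processor: subtracts the fitted mean."""

    def fit(self, data):
        self._context['mean'] = sum(data) / len(data)
        return self  # returning self allows fit(...).transform(...)

    @transform_validate
    def transform(self, data):
        return [x - self._context['mean'] for x in data]


train = [1.0, 2.0, 3.0]
new = [10.0, 20.0]

proc = ListMeanProcessor()

# Before fit, the context is empty, so the decorator raises.
try:
    proc.transform(new)
    guarded = False
except NotImplementedError:
    guarded = True

# After fit, the mean learned from `train` is reused on `new`.
centered_new = proc.fit(train).transform(new)
print(guarded, centered_new)
```

Because the guard only checks whether _context is empty, a processor that fills _context in its constructor (as OffsetProcessor does) would pass the check without ever calling fit; that is harmless here, since such a processor has nothing to learn.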