责任链模式通常用于流式数据处理、请求/响应中间件等场景：将多个过滤器对象连成一条链，请求沿着这条链依次传递。它也可以用于面向切面编程（AOP）的场景：责任链中的每个过滤器对象以插件的形式提供给主流程，主流程只关心责任链的构造和执行；每个过滤器插件可以单独开发，并通过配置动态加载到责任链中，从而实现与主流程的解耦。
本文提供了责任链模式的两种实现。
定义主业务逻辑基类
from abc import ABCMeta, abstractmethod
class Method:
    """Base class for the core business logic that a filter chain wraps.

    Subclasses override ``process_direct`` for the synchronous path and/or
    ``async_process_direct`` for the asynchronous path. Callers use the
    ``process`` / ``async_process`` entry points, which simply delegate.
    """

    def process_direct(self, *args, **kwargs):
        """Synchronous business logic; must be overridden by subclasses."""
        raise NotImplementedError

    def process(self, *args, **kwargs):
        """Synchronous entry point: forward all arguments to ``process_direct``."""
        handler = self.process_direct
        return handler(*args, **kwargs)

    async def async_process_direct(self, *args, **kwargs):
        """Asynchronous business logic; must be overridden by subclasses."""
        raise NotImplementedError

    async def async_process(self, *args, **kwargs):
        """Asynchronous entry point: await ``async_process_direct`` and return its result."""
        handler = self.async_process_direct
        outcome = await handler(*args, **kwargs)
        return outcome
定义过滤器接口
class IFilter:
    """Interface for a single link in a filter chain.

    An implementation receives the chain, the wrapped business ``Method``
    and the call arguments. The filters defined alongside this interface
    continue processing by calling ``chain.process`` (sync) or awaiting
    ``chain.async_process`` (async) with the same arguments.
    """

    def process(self, chain, method: Method, *args, **kwargs):
        """Synchronous filter hook; subclasses must override."""
        raise NotImplementedError

    async def async_process(self, chain, method: Method, *args, **kwargs):
        """Asynchronous filter hook; subclasses must override."""
        raise NotImplementedError
定义前置过滤器基类
class BeforeFilter(IFilter):
    """Pre-filter: runs a hook first, then hands control to the rest of the chain.

    Subclasses implement ``before_process`` (sync) and/or
    ``async_before_process`` (async). The hook receives the same positional
    and keyword arguments as the chain call; its return value is discarded.
    """

    def process(self, chain, method: Method, *args, **kwargs):
        """Run ``before_process``, then continue the chain and return its result."""
        self.before_process(*args, **kwargs)
        downstream = chain.process(*args, **kwargs)
        return downstream

    async def async_process(self, chain, method: Method, *args, **kwargs):
        """Await ``async_before_process``, then await the rest of the chain."""
        await self.async_before_process(*args, **kwargs)
        downstream = await chain.async_process(*args, **kwargs)
        return downstream

    def before_process(self, *args, **kwargs):
        """Synchronous pre-processing hook; must be overridden."""
        raise NotImplementedError

    async def async_before_process(self, *args, **kwargs):
        """Asynchronous pre-processing hook; must be overridden."""
        raise NotImplementedError
定义后置过滤器基类
class AfterFilter(IFilter):
    """Post-filter: lets the rest of the chain run first, then runs a hook.

    Subclasses implement ``after_process`` (sync) and/or
    ``async_after_process`` (async). The hook receives the original call
    arguments, not the downstream result. It only runs when the downstream
    chain returns normally; an exception propagates without the hook
    being called.
    """

    def process(self, chain, method: Method, *args, **kwargs):
        """Continue the chain, run ``after_process``, and return the chain's result."""
        downstream = chain.process(*args, **kwargs)
        self.after_process(*args, **kwargs)
        return downstream

    async def async_process(self, chain, method: Method, *args, **kwargs):
        """Await the rest of the chain, then ``async_after_process``; return the chain's result."""
        downstream = await chain.async_process(*args, **kwargs)
        await self.async_after_process(*args, **kwargs)
        return downstream

    def after_process(self, *args, **kwargs):
        """Synchronous post-processing hook; must be overridden."""
        raise NotImplementedError

    async def async_after_process(self, *args, **kwargs):
        """Asynchronous post-processing hook; must be overridden."""
        raise NotImplementedError
定义责任链基类
class IFilterChain(metaclass=ABCMeta):
    """Abstract interface for a filter chain.

    Because ``next_filter`` is abstract, neither this class nor a subclass
    that does not implement it can be instantiated.
    """

    @abstractmethod
    def next_filter(self):
        """Return the next filter in the chain."""
        raise NotImplementedError()
class FilterChain(IFilterChain):
    """Ordered chain of filters wrapped around a single business ``Method``.

    ``process`` / ``async_process`` hand this chain to the next filter (in
    ``add_filter`` order). A filter continues by calling
    ``chain.process`` / ``chain.async_process`` again. After the last filter,
    the wrapped ``method`` runs, and its result is returned back up through
    the filters.

    ``index`` is the position of the next filter to run. It is reset to 0
    when the outermost call returns or raises, so the same chain object can
    be invoked repeatedly. Previously a second call found ``index`` already
    at the end and silently skipped every filter.

    NOTE: the cursor lives on the instance, so a single chain must not be
    traversed by overlapping calls (e.g. two concurrent ``async_process``
    tasks); build one chain per in-flight request in that case.
    """

    def __init__(self, method: Method):
        self.method: Method = method
        self.filters: list = []
        # Position of the next filter to hand control to.
        self.index: int = 0

    def add_filter(self, filter_obj: IFilter):
        """Append ``filter_obj`` to the end of the chain; falsy values (e.g. None) are skipped."""
        if filter_obj:
            self.filters.append(filter_obj)

    def next_filter(self):
        """Return the filter at ``index`` and advance ``index`` by one.

        Raises:
            IndexError: if every filter has already been handed out.
        """
        filter_obj = self.filters[self.index]
        self.index += 1
        return filter_obj

    def process(self, *args, **kwargs):
        """Run the remaining filters, then ``method.process``; return the resulting value."""
        # Only the call that starts a traversal (cursor at 0) resets it, so
        # the nested calls made by filters keep advancing the same cursor.
        is_outermost = self.index == 0
        try:
            if self.index < len(self.filters):
                return self.next_filter().process(self, self.method, *args, **kwargs)
            return self.method.process(*args, **kwargs)
        finally:
            # Also runs on exceptions and when a filter short-circuits,
            # leaving the chain ready for the next invocation.
            if is_outermost:
                self.index = 0

    async def async_process(self, *args, **kwargs):
        """Async counterpart of ``process``: await the remaining filters, then ``method.async_process``."""
        # Same outermost-call reset as in ``process``.
        is_outermost = self.index == 0
        try:
            if self.index < len(self.filters):
                return await self.next_filter().async_process(self, self.method, *args, **kwargs)
            return await self.method.async_process(*args, **kwargs)
        finally:
            if is_outermost:
                self.index = 0
单元测试
class Context:
    """Mutable test fixture that records the order in which steps ran."""

    def __init__(self):
        # Each business step / filter in the tests appends a marker here.
        self.value = list()
class BusinessMethod(Method):
    """Demo business logic: records ``2`` (sync) or ``22`` (async) on the context."""

    def process_direct(self, context):
        """Append the sync marker ``2`` to ``context.value`` and return ``context``."""
        markers = context.value
        markers.append(2)
        return context

    async def async_process_direct(self, context):
        """Append the async marker ``22`` to ``context.value`` and return ``context``."""
        markers = context.value
        markers.append(22)
        return context
class BusinessBeforeFilter(BeforeFilter):
    """Demo pre-filter: records ``1`` (sync) or ``11`` (async) on the context."""

    def before_process(self, context):
        """Append the sync pre-filter marker ``1``."""
        markers = context.value
        markers.append(1)

    async def async_before_process(self, context):
        """Append the async pre-filter marker ``11``."""
        markers = context.value
        markers.append(11)
class BusinessAfterFilter(AfterFilter):
    """Demo post-filter: records ``3`` (sync) or ``33`` (async) on the context."""

    def after_process(self, context):
        """Append the sync post-filter marker ``3``."""
        markers = context.value
        markers.append(3)

    async def async_after_process(self, context):
        """Append the async post-filter marker ``33``."""
        markers = context.value
        markers.append(33)
def test_filter_chain_process():
    """Sync path: pre-filter, business method, then post-filter run in order."""
    chain = FilterChain(BusinessMethod())
    for filter_obj in (BusinessBeforeFilter(), BusinessAfterFilter()):
        chain.add_filter(filter_obj)
    outcome = chain.process(Context())
    assert outcome.value == [1, 2, 3]
def test_filter_chain_awaitable_process():
    """Async path: pre-filter, business method, then post-filter run in order.

    Driven with ``asyncio.run`` so the test actually executes under plain
    pytest. The previous version was an unmarked ``async def`` test, which
    pytest-asyncio's default (strict) mode does not run. It also depended on
    the ``event_loop`` fixture, which pytest-asyncio deprecated and later
    removed.
    """
    import asyncio

    filter_chain = FilterChain(BusinessMethod())
    filter_chain.add_filter(BusinessBeforeFilter())
    filter_chain.add_filter(BusinessAfterFilter())
    context = Context()
    result = asyncio.run(filter_chain.async_process(context))
    assert result.value == [11, 22, 33]