在现代软件开发中,模块化设计是提高代码可维护性和可扩展性的关键技术之一。本文将探讨如何使用注册表(Registry)和装饰器函数(Decorator Function)来实现模块化设计,提升代码的灵活性和可扩展性。
什么是注册表(Registry)?
注册表是一种设计模式,用于集中管理和访问不同模块。在 Python 中,注册表通常是一个字典,用来存储模块名称和模块类对象的映射关系。通过注册表,可以方便地对模块进行注册、获取和列出操作,从而实现模块化管理。
基本的注册表实现
以下是一个基本的注册表实现:
from loguru import logger
class Registry(object):
"""This class is used to register some modules to registry by a repo name."""
def __init__(self, name: str):
self._name = name
self._modules = {}
@property
def name(self):
return self._name
@property
def modules(self):
return self._modules
def list(self):
for m in self._modules.keys():
logger.info(f'{self._name}\t{m}')
def get(self, module_key):
return self._modules.get(module_key, None)
def _register_module(self, module_name=None, module_cls=None, force=False):
if module_name is None:
module_name = module_cls.__name__
if module_name in self._modules and not force:
raise KeyError(f'{module_name} is already registered in {self._name}')
self._modules[module_name] = module_cls
module_cls._name = module_name
def register_module(self, module_name: str = None, module_cls: type = None, force=False):
if not (module_name is None or isinstance(module_name, str)):
raise TypeError(f'module_name must be either of None, str, got {type(module_name)}')
if module_cls is not None:
self._register_module(module_name=module_name, module_cls=module_cls, force=force)
return module_cls
def _register(module_cls):
self._register_module(module_name=module_name, module_cls=module_cls, force=force)
return module_cls
return _register
特别注意register_module中的_register,register_module
方法根据传入的 module_name
和 module_cls
参数,返回一个装饰器函数 _register
。
使用装饰器函数注册模块
装饰器函数是一种高阶函数,可以在不改变原函数或类的前提下,动态地增加功能。在注册表模式中,装饰器函数可以用于注册模块,简化模块注册的过程。
以下是如何使用装饰器函数注册模块的示例:
DATASETS = Registry('Datasets')
@DATASETS.register_module('ChineseDataset')
class ChineseDataset:
def __init__(self, oss_dir, max_sample_per_cat=50, shuffle=False):
self.oss_dir = oss_dir
self.max_sample_per_cat = max_sample_per_cat
self.shuffle = shuffle
def info(self):
# 实现具体的逻辑
pass
def format(self, res):
# 实现具体的逻辑
pass
在这个示例中,ChineseDataset
类通过装饰器函数 @DATASETS.register_module('ChineseDataset')
自动注册到 DATASETS
注册表中。
装饰器函数的工作原理
在上面的代码中,@DATASETS.register_module('ChineseDataset')
是一个装饰器函数调用。这个装饰器函数的工作原理如下:
DATASETS.register_module('ChineseDataset')
返回一个装饰器函数_register
。@_register
装饰器应用于ChineseDataset
类,将其作为参数传递给_register
函数。_register
函数内部调用_register_module
方法,将ChineseDataset
类注册到DATASETS
注册表中。
具体的 _register
函数定义如下:
def register_module(self, module_name: str = None, module_cls: type = None, force=False):
if not (module_name is None or isinstance(module_name, str)):
raise TypeError(f'module_name must be either of None, str, got {type(module_name)}')
if module_cls is not None:
self._register_module(module_name=module_name, module_cls=module_cls, force=force)
return module_cls
def _register(module_cls):
self._register_module(module_name=module_name, module_cls=module_cls, force=force)
return module_cls
return _register
在这个实现中,register_module
方法根据传入的 module_name
和 module_cls
参数,返回一个装饰器函数 _register
。当 ChineseDataset
类被装饰时,_register
函数将其注册到 DATASETS
注册表中。
实例化 ChineseDataset
类
实例化 ChineseDataset
类的步骤发生在 deal_one_dataset
函数中,通过以下代码实现:
ds = DATASETS.get(name)(oss_path)
这段代码首先通过 DATASETS.get(name)
获取注册的 ChineseDataset
类,然后通过 (oss_path)
对其进行实例化。这样,ChineseDataset
类的一个实例就被创建,并存储在 ds
变量中,用于后续的数据处理操作。
为了更清晰地解释,让我们回顾一下实例化 ChineseDataset
类的完整过程。在下面的代码片段中,实例化过程发生在 deal_one_dataset
函数中:
import os, json, argparse
from tqdm import tqdm
import multiprocessing as mp
from src import DATASETS, utils
def parse_args():
parser = argparse.ArgumentParser(description='data_parsing', formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--name", type=str, default='ChineseDataset', help="dataset name")
parser.add_argument("--oss-dir", type=str, default='/root/dengbing/bigdata/benchmark/extract/ChineseDataset/', help="oss directory of the dataset")
parser.add_argument("--save-path", type=str, default='output.jsonl', help="path for the generated jsonl")
args = parser.parse_args()
return args
def deal_one_dataset(name, oss_path):
ds = DATASETS.get(name)(oss_path) # 这里实例化 ChineseDataset 类
res = ds.info()
res = ds.format(res)
return res
if __name__ == "__main__":
args = parse_args()
with open(args.save_path, 'w') as outfile:
res = deal_one_dataset(args.name, args.oss_dir)
for r in res:
j = utils.format_decimals(r)
outfile.write(json.dumps(j, ensure_ascii=False) + '\n')
具体步骤
-
解析命令行参数:
parse_args
函数解析命令行参数,获取数据集名称(--name
)、OSS 目录(--oss-dir
)和保存路径(--save-path
)。
-
处理单个数据集:
deal_one_dataset
函数根据传入的name
和oss_path
参数,处理单个数据集。- 在这一步中,
DATASETS.get(name)
获取已经注册的ChineseDataset
类,然后通过(oss_path)
对其进行实例化。
-
实例化
ChineseDataset
类:ds = DATASETS.get(name)(oss_path)
实际上是两步操作:DATASETS.get(name)
获取注册在DATASETS
注册表中的ChineseDataset
类。(oss_path)
调用类的构造函数,创建一个ChineseDataset
类的实例,并传递oss_path
参数。
-
调用数据集方法:
res = ds.info()
调用实例的info
方法,获取数据集信息。res = ds.format(res)
调用实例的format
方法,格式化数据集信息。
-
保存结果:
- 主程序部分将处理后的数据集结果保存到指定的 JSONL 文件中。
总结
通过使用装饰器函数,ChineseDataset
类可以自动注册到 DATASETS
注册表中,而不需要显式地调用注册方法。这种方式简化了模块的注册过程,使代码更加简洁和易于维护。
装饰器函数和注册表的结合使用,提高了代码的灵活性和可扩展性,使得模块化设计更加高效。在实际开发中,这种设计模式可以广泛应用于插件系统、数据处理管道和其他需要动态管理模块的场景。