transformers库——HfArgumentParser

HfArgumentParser是transformers.hf_argparser.py内的一个类,进一步封装了python的argparse模块,主要处理命令行的参数。

代码解读

hf_argparser.py中首先使用typing模块下的NewType类定义了两个新的数据类型,NewType的第一个参数是数据类型的名称,该数据类型具体可以是Any类型(即各种类型)。

DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)

init函数解读

class HfArgumentParser(ArgumentParser):
    """
    argparse.ArgumentParser的子类,使用数据类(dataclasses)来生成参数。

这个类被设计成可以很好地与python的argparse配合使用。特别是,您可以在初始化之后向解析器添加更多(非数据类支持的)参数,并且您将在解析后获得作为附加名称空间的输出。可选:要创建子参数组,请在数据类中使用' _argument_group_name '属性。
    """

    dataclass_types: Iterable[DataClassType]  # 定义一个DataClassType类型的可迭代对象

    def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
        """
        Args:
            dataclass_types:
                Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
            kwargs (`Dict[str, Any]`, *optional*):
                Passed to `argparse.ArgumentParser()` in the regular way.
        """
        # To make the default appear when using --help
        if "formatter_class" not in kwargs:
            kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
        super().__init__(**kwargs)
        if dataclasses.is_dataclass(dataclass_types):	# 判断dataclass_types是否是一个数据类
            dataclass_types = [dataclass_types]		# 是,则将数据类变为一个数据类列表
        self.dataclass_types = list(dataclass_types)
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)

该类在__init__的初始化阶段将datalass_types参数传递进来的各种数据类的参数(属性)注册到argparse中,并返回argparse.ArgumentParser对象——解析器。
使用例子:parser = HfArgumentParser(_TRAIN_ARGS)

parse_dict函数解读

def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
    """
    Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
    types.

    Args:
        args (`dict`):
            dict containing config values
        allow_extra_keys (`bool`, *optional*, defaults to `False`):
            Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.

    Returns:
        Tuple consisting of:

            - the dataclass instances in the same order as they were passed to the initializer.
    """
    unused_keys = set(args.keys())		# 去重
    outputs = []
    for dtype in self.dataclass_types:		# parse中注册的数据类
        keys = {f.name for f in dataclasses.fields(dtype) if f.init}
        inputs = {k: v for k, v in args.items() if k in keys}
        unused_keys.difference_update(inputs.keys())
        obj = dtype(**inputs)
        outputs.append(obj)
    if not allow_extra_keys and unused_keys:		# 如果不允许有额外的参数但命令行传入了额外的参数的情况下,丢出错误
        raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
    return tuple(outputs)		# 返回parse中注册数据类对象元组

使用参数字典args中的各种参数值创建并初始化parse中注册的各个数据类,并返回数据类对象元组。
使用例子:parser.parse_dict(args)

parse_yaml_file函数

def parse_yaml_file(
    self, yaml_file: Union[str, os.PathLike], allow_extra_keys: bool = False
) -> Tuple[DataClass, ...]:
    """
    Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the
    dataclass types.

    Args:
        yaml_file (`str` or `os.PathLike`):
            File name of the yaml file to parse
        allow_extra_keys (`bool`, *optional*, defaults to `False`):
            Defaults to False. If False, will raise an exception if the json file contains keys that are not
            parsed.

    Returns:
        Tuple consisting of:

            - the dataclass instances in the same order as they were passed to the initializer.
    """
    outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
    return tuple(outputs)

读取yaml文件中的各种参数值,使用获取到的参数值创建并初始化parse中注册的各个数据类,并返回数据类对象元组。
使用例子:

if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):    # 如果sys.argv得到的参数个数等于2,说明命令只有py文件名和配置文件名
    return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))

parse_json_file函数

    def parse_json_file(
        self, json_file: Union[str, os.PathLike], allow_extra_keys: bool = False
    ) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
        dataclass types.

        Args:
            json_file (`str` or `os.PathLike`):
                File name of the json file to parse
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the json file contains keys that are not
                parsed.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
        """
        with open(Path(json_file), encoding="utf-8") as open_json_file:
            data = json.loads(open_json_file.read())
        outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
        return tuple(outputs)

读取json文件中的各种参数值,使用获取到的参数值创建并初始化parse中注册的各个数据类,并返回数据类对象元组。
使用例子:

if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):    # 如果sys.argv得到的参数个数等于2,说明命令只有py文件名和配置文件名
    return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))

parse_arg_into_dataclasses

def parse_args_into_dataclasses(
        self,
        args=None,
        return_remaining_strings=False,
        look_for_args_file=True,
        args_filename=None,
        args_file_flag=None,
    ) -> Tuple[DataClass, ...]:

从命令行命令中获取参数,并用获取到的参数值创建并初始化parse中注册的各个数据类,并返回数据类对象元组。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值