class HfArgumentParser(ArgumentParser):
"""
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
arguments to the parser after initialization and you'll get the output back after parsing as an additional
namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.
"""
dataclass_types: Iterable[DataClassType]
def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
"""
Args:
dataclass_types:
Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
kwargs (`Dict[str, Any]`, *optional*):
Passed to `argparse.ArgumentParser()` in the regular way.
"""
# To make the default appear when using --help
if "formatter_class" not in kwargs:
kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
super().__init__(**kwargs)
if dataclasses.is_dataclass(dataclass_types):
dataclass_types = [dataclass_types]
self.dataclass_types = list(dataclass_types)
for dtype in self.dataclass_types:
self._add_dataclass_arguments(dtype)
@staticmethod
def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
field_name = f"--{field.name}"
kwargs = field.metadata.copy()
# field.metadata is not used at all by Data Classes,
# it is provided as a third-party extension mechanism.
if isinstance(field.type, str):
raise RuntimeError(
"Unresolved type detected, which should have been done with the help of "
"`typing.get_type_hints` method by default"
)
aliases = kwargs.pop("aliases", [])
if isinstance(aliases, str):
aliases = [aliases]
逐行解释这段代码。
类定义和文档字符串
python
复制
class HfArgumentParser(ArgumentParser):
"""
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
arguments to the parser after initialization and you'll get the output back after parsing as an additional
namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.
"""
- 定义一个名为
HfArgumentParser
的类,继承自ArgumentParser
。 - 该类的文档字符串解释了它的功能:使用数据类(dataclass)上的类型提示来生成参数。
- 该类与原生的
argparse
兼容,可以在初始化后添加更多非数据类支持的参数,并在解析后作为附加命名空间返回。 - 可选地,可以使用数据类中的
_argument_group_name
属性创建子参数组。
dataclass_types
属性
python
复制
dataclass_types: Iterable[DataClassType]
- 类属性
dataclass_types
,其类型为Iterable[DataClassType]
,表示该类支持多个数据类类型。
初始化方法
python
复制
def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
"""
Args:
dataclass_types:
Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
kwargs (`Dict[str, Any]`, *optional*):
Passed to `argparse.ArgumentParser()` in the regular way.
"""
- 初始化方法
__init__
,接受两个参数:dataclass_types
:单个数据类类型或数据类类型的列表,这些数据类将用解析后的参数实例化。kwargs
:可选的字典,传递给argparse.ArgumentParser()
。
设置默认格式化类
python
复制
# To make the default appear when using --help
if "formatter_class" not in kwargs:
kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
- 检查
kwargs
中是否包含formatter_class
,如果没有,则设置为ArgumentDefaultsHelpFormatter
,以便在使用--help
时显示默认值。
调用父类的初始化方法
python
复制
super().__init__(**kwargs)
- 调用父类
ArgumentParser
的初始化方法,将kwargs
传递给它。
检查和转换 dataclass_types
python
复制
if dataclasses.is_dataclass(dataclass_types):
dataclass_types = [dataclass_types]
self.dataclass_types = list(dataclass_types)
- 检查
dataclass_types
是否是单个数据类,如果是,则将其转换为列表。 - 将
dataclass_types
转换为列表并赋值给实例属性self.dataclass_types
。
添加数据类的参数
python
复制
for dtype in self.dataclass_types:
self._add_dataclass_arguments(dtype)
- 遍历
self.dataclass_types
,并为每个数据类类型调用_add_dataclass_arguments
方法,以添加数据类的参数。
静态方法 _parse_dataclass_field
python
复制
@staticmethod
def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
field_name = f"--{field.name}"
kwargs = field.metadata.copy()
# field.metadata is not used at all by Data Classes,
# it is provided as a third-party extension mechanism.
if isinstance(field.type, str):
raise RuntimeError(
"Unresolved type detected, which should have been done with the help of "
"`typing.get_type_hints` method by default"
)
aliases = kwargs.pop("aliases", [])
if isinstance(aliases, str):
aliases = [aliases]
- 定义静态方法
_parse_dataclass_field
,用于解析数据类字段。 - 参数:
parser
:ArgumentParser
实例。field
: 数据类字段。
- 创建字段名,格式为
--{field.name}
。 - 复制字段的
metadata
属性到kwargs
。 - 检查字段类型是否为字符串,如果是,则抛出运行时错误,提示类型未解析。
- 从
kwargs
中获取别名(aliases
),如果是字符串,则转换为列表。
通过以上逐行解释,我们详细了解了 HfArgumentParser
类的定义和功能。