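HfArgumentParser is a class defined in transformers' hf_argparser.py (module transformers.hf_argparser). It subclasses Python's argparse.ArgumentParser and builds on the argparse module. Its main job is to turn command-line arguments into dataclass instances.
Code walkthrough
hf_argparser.py starts by using NewType from the typing module to define two new types. The first argument to NewType is the name of the new type. The second is its underlying type, here Any (that is, any type).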
DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
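The small check below shows what these definitions do. At runtime a NewType is only an identity callable with a `__supertype__` attribute. The difference between `DataClass` and `Any` matters only to static type checkers. The variable names `config` and `wrapped` are made up for illustration.

```python
from typing import Any, NewType

# Same definitions as in hf_argparser.py
DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)

config = {"lr": 1e-4}
wrapped = DataClass(config)  # at runtime NewType returns its argument unchanged

print(wrapped is config)               # True
print(DataClass.__name__)              # DataClass
print(DataClass.__supertype__ is Any)  # True
```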
__init__ function walkthrough
class HfArgumentParser(ArgumentParser):
    """
    Subclass of argparse.ArgumentParser that uses dataclasses to generate arguments.
    The class is designed to work well with Python's native argparse. In particular, you can add more (non-dataclass-backed) arguments to the parser after initialization, and after parsing you will get them back as an additional namespace. Optional: to create sub-argument groups, use the `_argument_group_name` attribute in the dataclass.
    """
    dataclass_types: Iterable[DataClassType]  # class-level annotation: an iterable of DataClassType
    def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
        """
        Args:
            dataclass_types:
                Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
            kwargs (`Dict[str, Any]`, *optional*):
                Passed to `argparse.ArgumentParser()` in the regular way.
        """
        # To make the default appear when using --help
        if "formatter_class" not in kwargs:
            kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
        super().__init__(**kwargs)
        if dataclasses.is_dataclass(dataclass_types):  # is dataclass_types a single dataclass?
            dataclass_types = [dataclass_types]  # if so, wrap it in a list
        self.dataclass_types = list(dataclass_types)
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)  # register this dataclass's fields as arguments
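During __init__, the class registers the fields (attributes) of every dataclass passed in through `dataclass_types` as argparse arguments. HfArgumentParser is itself a subclass of argparse.ArgumentParser, so the object being constructed is the parser.
Usage example (`_TRAIN_ARGS` is a collection of dataclass types): parser = HfArgumentParser(_TRAIN_ARGS)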
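Here is a minimal sketch of what the constructor accomplishes, using only the standard library. `MiniArgumentParser`, `TrainArgs`, and the body of `_add_dataclass_arguments` are hypothetical simplifications. The sketch turns each `init` field into a `--<field>` option and uses the field's annotation as the type converter. The real `_add_dataclass_arguments` handles many more cases, such as booleans, enums, lists, and help text from field metadata.

```python
import argparse
import dataclasses
from dataclasses import dataclass


@dataclass
class TrainArgs:  # hypothetical example dataclass
    model_name: str = "gpt2"
    lr: float = 5e-5
    epochs: int = 3


class MiniArgumentParser(argparse.ArgumentParser):
    """Hypothetical, heavily simplified stand-in for HfArgumentParser (scalar fields only)."""

    def __init__(self, dataclass_types, **kwargs):
        # Same idea as the original: show defaults in --help output
        kwargs.setdefault("formatter_class", argparse.ArgumentDefaultsHelpFormatter)
        super().__init__(**kwargs)
        if dataclasses.is_dataclass(dataclass_types):  # a single dataclass -> wrap it in a list
            dataclass_types = [dataclass_types]
        self.dataclass_types = list(dataclass_types)
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)

    def _add_dataclass_arguments(self, dtype):
        # Simplified: one --<field> option per init field, annotation used as converter
        for f in dataclasses.fields(dtype):
            if not f.init:
                continue
            options = {"type": f.type}
            if f.default is not dataclasses.MISSING:
                options["default"] = f.default
            elif f.default_factory is not dataclasses.MISSING:
                options["default"] = f.default_factory()
            else:
                options["required"] = True  # no default -> must be given on the command line
            self.add_argument(f"--{f.name}", **options)


parser = MiniArgumentParser(TrainArgs)
namespace = parser.parse_args(["--lr", "1e-3"])
print(namespace.model_name, namespace.lr, namespace.epochs)  # gpt2 0.001 3
```

Because the parser subclass stores the dataclass types on itself, later parsing methods can map parsed values back to the right dataclass.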
parse_dict function walkthrough
def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
    """
    Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
    types.
    Args:
        args (`dict`):
            dict containing config values
        allow_extra_keys (`bool`, *optional*, defaults to `False`):
            Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.
    Returns:
        Tuple consisting of:
            - the dataclass instances in the same order as they were passed to the initializer.
    """
    unused_keys = set(args.keys())  # keys not yet consumed by any dataclass
    outputs = []
    for dtype in self.dataclass_types:  # dataclasses registered with the parser
        keys = {f.name for f in dataclasses.fields(dtype) if f.init}  # fields accepted by the dataclass __init__
        inputs = {k: v for k, v in args.items() if k in keys}
        unused_keys.difference_update(inputs.keys())
        obj = dtype(**inputs)
        outputs.append(obj)
    if not allow_extra_keys and unused_keys:  # extra keys are not allowed, but the dict has keys no dataclass uses: raise
        raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
    return tuple(outputs)  # tuple of dataclass instances, in registration order
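parse_dict uses the values in the dict `args` to create and initialize each dataclass registered with the parser. It returns a tuple of the dataclass instances. If a key matches no init field of any registered dataclass, it raises a ValueError unless allow_extra_keys=True.
Usage example: parser.parse_dict(args)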
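The logic of `parse_dict` does not depend on argparse, so it can be reproduced as a standalone function. The sketch below copies that logic into a hypothetical free function, also named `parse_dict`, that takes the dataclass types explicitly. It uses two made-up dataclasses. It also shows that a field declared with `init=False` is not filled from the dict, so its key counts as unused.

```python
import dataclasses
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, Tuple


@dataclass
class ModelArgs:  # hypothetical dataclass for illustration
    model_name: str = "gpt2"


@dataclass
class TrainArgs:  # hypothetical dataclass for illustration
    lr: float = 5e-5
    epochs: int = 3
    step: int = field(default=0, init=False)  # not settable through __init__


def parse_dict(
    dataclass_types: Iterable[type], args: Dict[str, Any], allow_extra_keys: bool = False
) -> Tuple[Any, ...]:
    unused_keys = set(args.keys())  # keys not yet claimed by any dataclass
    outputs = []
    for dtype in dataclass_types:
        # Only fields that the generated __init__ accepts
        keys = {f.name for f in dataclasses.fields(dtype) if f.init}
        inputs = {k: v for k, v in args.items() if k in keys}
        unused_keys.difference_update(inputs.keys())
        outputs.append(dtype(**inputs))
    if not allow_extra_keys and unused_keys:
        raise ValueError(f"Some keys are not used: {sorted(unused_keys)}")
    return tuple(outputs)


config = {"model_name": "bert-base-uncased", "lr": 1e-3}
model_args, train_args = parse_dict([ModelArgs, TrainArgs], config)
print(model_args)  # ModelArgs(model_name='bert-base-uncased')
print(train_args)  # TrainArgs(lr=0.001, epochs=3, step=0)

try:
    parse_dict([ModelArgs, TrainArgs], {"lr": 1e-3, "step": 10})
except ValueError as err:
    print(err)  # Some keys are not used: ['step']
```

Passing `allow_extra_keys=True` makes the last call succeed and silently ignore `step`.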
parse_yaml_file function
def parse_yaml_file(
    self, yaml_file: Union[str, os.PathLike], allow_extra_keys: bool = False
) -> Tuple[DataClass, ...]:
    """
    Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the
    dataclass types.
    Args:
        yaml_file (`str` or `os.PathLike`):
            File name of the yaml file to parse
        allow_extra_keys (`bool`, *optional*, defaults to `False`):
            Defaults to False. If False, will raise an exception if the json file contains keys that are not
            parsed.
    Returns:
        Tuple consisting of:
            - the dataclass instances in the same order as they were passed to the initializer.
    """
    outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
    return tuple(outputs)
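parse_yaml_file reads the parameter values from a YAML file using yaml.safe_load. It then delegates to parse_dict to create and initialize each dataclass registered with the parser, and returns a tuple of the dataclass instances.
Usage example (inside a function that returns the parsed arguments):
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):  # 2 entries in sys.argv: only the script name and the config file
    return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))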
parse_json_file function
def parse_json_file(
    self, json_file: Union[str, os.PathLike], allow_extra_keys: bool = False
) -> Tuple[DataClass, ...]:
    """
    Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
    dataclass types.
    Args:
        json_file (`str` or `os.PathLike`):
            File name of the json file to parse
        allow_extra_keys (`bool`, *optional*, defaults to `False`):
            Defaults to False. If False, will raise an exception if the json file contains keys that are not
            parsed.
    Returns:
        Tuple consisting of:
            - the dataclass instances in the same order as they were passed to the initializer.
    """
    with open(Path(json_file), encoding="utf-8") as open_json_file:
        data = json.loads(open_json_file.read())
    outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
    return tuple(outputs)
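parse_json_file reads the parameter values from a JSON file. It then delegates to parse_dict to create and initialize each dataclass registered with the parser, and returns a tuple of the dataclass instances.
Usage example (inside a function that returns the parsed arguments):
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):  # 2 entries in sys.argv: only the script name and the config file
    return parser.parse_json_file(os.path.abspath(sys.argv[1]))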
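`parse_json_file` only loads the file and delegates to `parse_dict`. That means the whole path from the argument list to dataclass instances can be sketched with the standard library. Everything here is a hypothetical stand-in: `TrainArgs`, the free functions `parse_dict`, `parse_json_file` and `parse_from_argv`, and the temporary `train.json` file. The `argv` list is passed in explicitly instead of reading `sys.argv`, so the example is easy to run.

```python
import dataclasses
import json
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple


@dataclass
class TrainArgs:  # hypothetical dataclass for illustration
    lr: float = 5e-5
    epochs: int = 3


def parse_dict(
    dataclass_types: Iterable[type], args: Dict[str, Any], allow_extra_keys: bool = False
) -> Tuple[Any, ...]:
    # Condensed version of the parse_dict logic discussed earlier
    unused_keys = set(args)
    outputs = []
    for dtype in dataclass_types:
        keys = {f.name for f in dataclasses.fields(dtype) if f.init}
        inputs = {k: v for k, v in args.items() if k in keys}
        unused_keys.difference_update(inputs)
        outputs.append(dtype(**inputs))
    if not allow_extra_keys and unused_keys:
        raise ValueError(f"Some keys are not used: {sorted(unused_keys)}")
    return tuple(outputs)


def parse_json_file(dataclass_types, json_file, allow_extra_keys: bool = False) -> Tuple[Any, ...]:
    # Load the JSON file into a dict, then hand it to parse_dict
    data = json.loads(Path(json_file).read_text(encoding="utf-8"))
    return parse_dict(dataclass_types, data, allow_extra_keys=allow_extra_keys)


def parse_from_argv(dataclass_types, argv: List[str]) -> Tuple[Any, ...]:
    # Same check as the usage example: argv holds only the script name and a .json path
    if len(argv) == 2 and argv[1].endswith(".json"):
        return parse_json_file(dataclass_types, os.path.abspath(argv[1]))
    raise NotImplementedError("plain command-line parsing is outside this sketch")


with tempfile.TemporaryDirectory() as tmp_dir:
    config_path = Path(tmp_dir) / "train.json"
    config_path.write_text(json.dumps({"lr": 0.001, "epochs": 5}), encoding="utf-8")
    (train_args,) = parse_from_argv([TrainArgs], ["train.py", str(config_path)])

print(train_args)  # TrainArgs(lr=0.001, epochs=5)
```

The YAML path has the same shape, with `yaml.safe_load` replacing `json.loads`.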
parse_args_into_dataclasses function
def parse_args_into_dataclasses(
    self,
    args=None,
    return_remaining_strings=False,
    look_for_args_file=True,
    args_filename=None,
    args_file_flag=None,
) -> Tuple[DataClass, ...]:
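parse_args_into_dataclasses parses the command-line arguments (sys.argv[1:] when `args` is None). It uses the parsed values to create and initialize each dataclass registered with the parser, and returns a tuple of the dataclass instances. Arguments added to the parser that are not backed by a dataclass come back as an additional namespace. With return_remaining_strings=True, any unparsed strings are returned as the last element of the tuple instead of raising an error.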
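Only the signature is shown above. Here is a simplified sketch of the core idea, assuming the parser already has one option per dataclass field. The steps are: call `parse_known_args`, hand each dataclass the namespace attributes that match its `init` fields, and deal with the leftovers. `build_parser` and `args_into_dataclasses` are hypothetical names. The sketch leaves out the `.args` file handling controlled by `look_for_args_file`, `args_filename` and `args_file_flag`.

```python
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple


@dataclass
class ModelArgs:  # hypothetical dataclass for illustration
    model_name: str = "gpt2"


@dataclass
class TrainArgs:  # hypothetical dataclass for illustration
    lr: float = 5e-5
    epochs: int = 3


def build_parser(dataclass_types) -> argparse.ArgumentParser:
    # Hypothetical helper: one --<field> option per init field; assumes every field has a plain default
    parser = argparse.ArgumentParser()
    for dtype in dataclass_types:
        for f in dataclasses.fields(dtype):
            if f.init:
                parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
    return parser


def args_into_dataclasses(
    parser: argparse.ArgumentParser,
    dataclass_types,
    args: Optional[Sequence[str]] = None,
    return_remaining_strings: bool = False,
) -> Tuple[Any, ...]:
    # parse_known_args reads sys.argv[1:] when args is None and returns unknown strings separately
    namespace, remaining = parser.parse_known_args(args=args)
    outputs: List[Any] = []
    for dtype in dataclass_types:
        keys = {f.name for f in dataclasses.fields(dtype) if f.init}
        inputs = {k: v for k, v in vars(namespace).items() if k in keys}
        for k in keys:
            delattr(namespace, k)  # consumed by this dataclass: remove from the namespace
        outputs.append(dtype(**inputs))
    if vars(namespace):
        # Leftover attributes come from extra, non-dataclass arguments added to the parser
        outputs.append(namespace)
    if return_remaining_strings:
        return (*outputs, remaining)
    if remaining:
        raise ValueError(f"Some specified arguments are not used: {remaining}")
    return tuple(outputs)


arg_types = [ModelArgs, TrainArgs]
parser = build_parser(arg_types)
model_args, train_args = args_into_dataclasses(
    parser, arg_types, ["--lr", "0.001", "--model_name", "bert-base-uncased"]
)
print(model_args.model_name, train_args.lr, train_args.epochs)  # bert-base-uncased 0.001 3

*parsed, rest = args_into_dataclasses(parser, arg_types, ["--epochs", "2", "--foo=1"], return_remaining_strings=True)
print(rest)  # ['--foo=1']
```

Deleting consumed attributes from the namespace is what lets any remaining, non-dataclass arguments be detected and returned as their own namespace.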