HfArgumentParser及dataclasses类(dataclass、field、asdict、__post_init__)相关

HfArgumentParser

HfArgumentParser继承了ArgumentParser类。
经常在NLP方面的大模型参数配置中看到其身影,感觉在CV中的话也可以尝试一下。

一般的用法是用其定义数据类,然后通过读取json文件来传输参数,在知道其用法之前我们需要一些dataclassfield的知识,这两个都是属于dataclasses类,之前看源码还会有一个asdict和**post_init**的方法,在此一并都介绍了。

dataclass

是一个数据类,一般用于对数据的定义,他会默认添加__repr__及__init__。

优点

1、默认添加__init__,避免麻烦的赋值操作
正常是如下所示

class MyClass:
 def __init__(self, var_a, var_b):
 self.var_a = var_a
 self.var_b = var_b

用了dataclass只需如下写法,@dataclass默认生成__init__:

from dataclasses import dataclass

@dataclass
class MyClass:
	var_a: str = 'a'
	var_b: str = 'b'

2、其他
其他还有很多优点,可以参考官方文档及其他博客,这里就不再赘述了。

filed

主要有以下功能:
1.可以定义类属性为数据格式,例如为一个空列表。代码如下:
你可能想是这样定义

from dataclasses import dataclass
from typing import List

@dataclass
class C:
    my_list: List[int] = []  

但这样会报错。故采用下面的field定义方法:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class C:
    my_list: Optional[List[int]] = field(default_factory=list)

asdict

这个很简单啦,就是把dataclass类转换成字典dict,来看一个例子:
运行代码不会报错,说明是相等的。

from dataclasses import dataclass
from typing import List, Optional
@dataclass
class Point:
    x: int
    y: int

@dataclass
class C:
    mylist: list[Point]

p = Point(10, 20)
assert asdict(p) == {'x': 10, 'y': 20}

c = C([Point(0, 0), Point(10, 4)])
assert asdict(c) == {'mylist': [{'x': 0, 'y': 0}, {'x': 10, 'y': 4}]}

post_init

有了dataclass,需要定义一个__init__方法来将变量赋给self这种初始化操作已经得到了处理。但是我们失去了在变量被赋值之后立即需要的函数调用或处理的灵活性。
生成的__init__方法在返回之前调用__post_init__返回。因此,可以在函数中进行任何处理。
所以提供了__post_init__方法来解决这个问题,来看代码示例:

from dataclasses import dataclass, field

@dataclass
class C:
    a: float
    b: float
    c: float = field(init=False)

    def __post_init__(self):
        self.c = self.a + self.b

c1 = C(10, 20)
print("c1:{0}".format(c1))

这样就会生成c的值了,在一些深度学习代码中__post_init__可以用来验证参数初始化是否正确。

HfArgumentParser常见使用

终于,我们可以介绍HfArgumentParser了,这个用法主要常见于一些深度学习的代码,而且是在参数的定义中。主要流程如下图所示:
在这里插入图片描述
代码如下所示:

from dataclasses import dataclass, field, fields
from typing import List, Optional
from transformers import HfArgumentParser

@dataclass
class QLoRAArguments:
    """
    一些自定义参数
    """
    max_seq_length: int = field(default=1, metadata={"help": "输入最大长度"})
    task_type: str = field(default="11", metadata={"help": "预训练任务:[sft, pretrain]"})
    eval_file: Optional[str] = field(default="", metadata={"help": "the file of training data"})
    lora_rank: Optional[int] = field(default=64, metadata={"help": "lora rank"})
    lora_alpha: Optional[int] = field(default=16, metadata={"help": "lora alpha"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "lora dropout"})
    output_dir: Optional[float] = field(default='./code')
    model_name_or_path: Optional[float] = None


@dataclass
class Second:
    second: int = field(default=2,metadata={'help': '22222'})


parse = HfArgumentParser((QLoRAArguments, Second))
training_args, secondary = parse.parse_json_file(json_file='11.json')
print(training_args.task_type)
print(secondary.second)

总的来说,HfArgumentParser对比常规的ArgumentParser参数定义方法比较好的一点是其清晰,可管理性强。以上述代码为例,我们在定义自己深度学习模型时,会有训练参数、常规配置参数、验证参数等,所以我们可以分好几个dataclass类,使用上述代码去分配。

注意事项
parse_json_file()

该方法返回的是一个元祖,若你只需要返回一个,则应使用下述方法,将上述代码对应行数修改一下即可。

training_args = parse.parse_json_file(json_file='11.json')[0]
parse_args_into_dataclasses()

这是其中的另一个方法,返回的是dataclasses类,用法也基本类似,如下:

parse = HfArgumentParser(QLoRAArguments)
args = parse.parse_args_into_dataclasses()[0]

详细信息就不再赘述了,更多的方法可以去官方文档等去查。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值