论文:Training Diffusion Models with Reinforcement Learning【链接】
源代码位于:github的trl包中【链接】
一、eval()
描述
eval() 函数用来执行一个字符串表达式,并返回表达式的值。字符串表达式可以包含变量、函数调用、运算符和其他 Python 语法元素。
注意:在神经网络中,用在验证集上,不计算梯度!!
注意:model.train() VS model.eval()
eval(expression[, globals[, locals]])
参数
expression – 表达式。
globals – 变量作用域,全局命名空间,如果被提供,则必须是一个字典对象。
locals – 变量作用域,局部命名空间,如果被提供,可以是任何映射对象。
实例
>>>x = 7
>>> eval( '3 * x' )
21
>>> eval('pow(2,2)')
4
>>> eval('2 + 2')
4
>>> n=81
>>> eval("n + 4")
85
二、with self.autocast():
使用自动混合精度训练(auto Mixed Precision,AMP)可以大幅度降低训练的成本并提高训练的速度。
三、Accelerator.register_save_state_pre_hook()
只能与[Accelerator.register_load_state_pre_hook
]一起使用。可用于保存
除了模型权重之外的配置。也可用于用自定义的覆盖模型保存
方法在这种情况下,请确保从权重列表中删除已加载的权重。
四、PerPromptStatTracker
用于跟踪每个提示的统计信息的类。主要用于计算DPPO算法的Advantages。
Args:
buffer_size(int)
:
为每个提示保留的缓冲区的大小。
min_count(int)
:
在计算平均值和标准值之前,要保留在缓冲区中的最小样本数。
五、NotImplementedError
在Python中,NotImplementedError
是一个内置异常类,用于表示一个方法或函数应该被实现,但实际上并没有被实现。它通常用于抽象基类(ABC)中,作为占位符,提醒子类必须覆盖这个方法。
六、accelerator.gather(tensor)
能将所有进程的tensor收集起来。
参考
Python eval() 函数
Pytorch自动混合精度(AMP)介绍与使用
【Python】进阶学习:一文了解NotImplementedError的作用
【Pytorch】model.train() 和 model.eval() 原理与用法