【学习笔记7】阅读ddpo代码

论文:Training Diffusion Models with Reinforcement Learning【链接
源代码位于:github的trl包中【链接

一、eval()

描述

eval() 函数用来执行一个字符串表达式,并返回表达式的值。字符串表达式可以包含变量、函数调用、运算符和其他 Python 语法元素。
注意:在神经网络中,用在验证集上,不计算梯度!!
注意:model.train() VS model.eval()

eval(expression[, globals[, locals]])

参数

expression – 表达式。
globals – 变量作用域,全局命名空间,如果被提供,则必须是一个字典对象。
locals – 变量作用域,局部命名空间,如果被提供,可以是任何映射对象。

实例

>>>x = 7
>>> eval( '3 * x' )
21
>>> eval('pow(2,2)')
4
>>> eval('2 + 2')
4
>>> n=81
>>> eval("n + 4")
85

二、with self.autocast():

使用自动混合精度训练(auto Mixed Precision,AMP)可以大幅度降低训练的成本并提高训练的速度。

三、Accelerator.register_save_state_pre_hook()

只能与[Accelerator.register_load_state_pre_hook]一起使用。可用于保存
除了模型权重之外的配置。也可用于用自定义的覆盖模型保存
方法在这种情况下,请确保从权重列表中删除已加载的权重。

四、PerPromptStatTracker

用于跟踪每个提示的统计信息的类。主要用于计算DPPO算法的Advantages。
Args:
buffer_size(int)
为每个提示保留的缓冲区的大小。
min_count(int)
在计算平均值和标准值之前,要保留在缓冲区中的最小样本数。

五、NotImplementedError

在Python中,NotImplementedError是一个内置异常类,用于表示一个方法或函数应该被实现,但实际上并没有被实现。它通常用于抽象基类(ABC)中,作为占位符,提醒子类必须覆盖这个方法。

六、accelerator.gather(tensor)

能将所有进程的tensor收集起来。

参考

Python eval() 函数
Pytorch自动混合精度(AMP)介绍与使用
【Python】进阶学习:一文了解NotImplementedError的作用
【Pytorch】model.train() 和 model.eval() 原理与用法

  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值