Stable Baselines/RL算法/RL基础类

Stable Baselines官方文档中文版 Github CSDN
尝试翻译官方文档,水平有限,如有错误万望指正

所有强化学习(RL)算法的公共接口

  • BaseRLModel

    class stable_baselines.common.base_class.BaseRLModel(policy, env, verbose=0, *, requires_vec_env, policy_base, policy_kwargs=None)
    

    基础RL模型。

    参数介绍:

    参数类型意义
    policyBasePolicy策略对象
    envGym environment学习环境(如果已在Gym注册,可以是str。如果为了载入已训练好的模型可以是None)
    verboseint信息显示级别:0是None;1是训练信息;2是tensorflow debug
    requires_vec_envbool此模型是否需要矢量化环境
    policy_baseBasePolicy此方法使用的基础策略

    函数介绍:

    1. action_probability()
    action_probability(observation, state=None, mask=None, actions=None, logp=False)
    

    如果actionsNone,那就从给定观测中获取模型行动的概率分布。

    依据行动空间有两种输出:

    • 离散:每种可能行动的概率
    • Box:行动输出的均值和标准差

    然而,如果actions不是None,这个函数会返回此模型采用给定行动和给定参数(观测,状态,…)的概率。对于离散行动空间,返回概率质量;对于连续行动空间,则是概率密度。这是因为在连续空间,概率质量总是0,详细解释见此链接

    参数类型意义
    observationnp.ndassay输入观测
    statenp.ndarray最新状态(可以是None,用于迭代策略)
    masknp.ndarray最新掩码(可以是None,用于迭代策略)
    actionsnp.ndarray(可选参数)计算模型为每个指定参数选择指定行动的可能性。行动和观测必须具有相同数量(设为None则返回完整的行动概率分布)
    logpbool(可选参数)当指定行动,返回log空间中的概率。如果行动是None则无影响

    ***返回:***(np.ndarray) 模型的(log)行动概率

    1. get_env()

      返回当前环境(如果没有定义则返回None

      ***返回:***(Gym Environment)当前环境

    2. get_parameter_list()

    获取模型参数的tensorflow变量

    这包含了连续训练所必须的所有变量(保存/载入)

    ***返回:***(listtensorflow变量列表

    1. get_parameters()

      获取当前模型参数作为变量名字典–>ndarray

      ***返回:*OrderedDict)变量名字典–>模型擦书的ndarray

    2. learn()

      learn(total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name='run', reset_num_timesteps=True)
      

      返回一个训练好的模型

      参数类型意义
      total_timestepsint训练样本的总数
      seedint训练的初始种子,如果是None:保持当前种子
      callbackfunction (dict, dict)算法状态的每一步都调用的布尔函数。它接受本地或全局变量。如果返回False,终止训练
      log_intervalint记录日志之前的时间步数
      tb_log_namestr运行tensorboard日志的名字

      ***返回:***(BaseRLModel)训练好的模型

    3. 类方法 load()

    classmethod load(load_path, env=None, **kwargs)  
    

    从文件中载入模型

    参数类型意义
    load_pathstr or file-like参数保存位置
    envGym Envrionment载入模型运行的环境(如果你只是需要从一个已训练模型进行预测可以是是None)
    kwargs载入模型时对模型有改变作用的其他参数
    1. load_parameters()

      load_parameters(load_path_or_dict, exact_match=True) 
      

      从文件或字典中载入模型参数

      字典关键字应该时tensorflow变量名称,可以用get_parameters函数获取。如果exact_matchTrue,字典应该包含所有模型参数的关键字。否则,出现RunTimeError。如果时False,只有字典内的参数会被更新。

      此函数没有载入agent的超参数。

      警告:

      此函数没有更新训练器/优化器参数(例如动量)。因为使用此函数后的这种训练会导致不太理想的结果。

      参数类型意义
      load_path_or_dictstr or file-like or dict参数保存位置或参数字典
      exact_matchbool如果是True,期望载入字典包含此模型的所有参数。如果是False,只为字典提到的变量载入参数,默认True
    2. predict()

      predict(observation, state=None, mask=None, deterministic=False) 
      

      从一个观测得到模型的行动

      参数类型意义
      observationnp.ndarray输入观测
      statenp.ndarray最新状态(可以是None,用于迭代策略)
      masknp.ndarray最新掩码(可以是None,用于迭代策略)
      deterministicbool是否返回确定性的行动

      ***返回:***(np.ndarray, np.ndarray)模型的行动和下个状态(用在迭代策略)

    3. pretrain()

      pretrain(dataset, n_epochs=10, learning_rate=0.0001, adam_epsilon=1e-08, val_interval=None) 
      

      用行为克隆预训练一个模型:在给定专家数据集监督学习

      目前只支持Box和离散空间。

      参数类型意义
      datasetExpertDataset数据集管理器
      n_epochsint在训练集上的迭代次数
      learning_ratefloat学习率
      adam_epsilonfloatadam优化器的 ϵ \epsilon ϵ
      val_intervalint报告每n轮训练和验证的损失。默认,最大代数的十分之一

      ***返回:***(BaseRLModel)预训练模型

    4. save(save_path)

      将当前参数保存到文件

      参数类型意义
      save_pathstr or file-like object保存位置
    5. set_env(env)

      验证环境的有效性,如果是连贯的,设为当前环境。

      参数类型意义
      envGym Environment学习策略的环境
    6. setup_model()

      创建训练模型所需的所有函数和tensorflow图表

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值