Zee带你看代码系列
学习强化学习,码代码的能力必须要出众,要快速入门强化学习 搞清楚其中真正的原理,读源码是一个最简单的最直接的方式。最近创建了一系列该类型文章,希望对大家有多帮助。
另外,我会将所有的文章及所做的一些简单项目,放在我的个人网页上。
水平有限,可能有理解不到位的地方,希望大家主动沟通交流。
邮箱:zachary2wave@163.com
OPENAI Baeslines 详解(五)保存数据、模型
数据保存
把环境的数据保存下来是找问题原因的一个关键技巧,利用baselines 的函数可以轻松地保存数据成各种形式。
Baseline有两种保存数据的方式:一种是建立Monitor 一种是Callbacks, 两种办法都是可行的。
但是除了deepq保留了callbacks的接口 其他地方都没有保留,但是baselines是基于tensorflow的,所以有利器tensorboard。 这样,观察整个训练过程,变得更容易 更强大。
Monitor
Monitor 监视器,相当于将env进行一层包装Wrapper,将env 放在监视之下。
from baselines.bench import Monitor
env = Monitor(env, log_path, allow_early_resets=True)
# 输入的env为gym.make创建的,如果是多env环境会报错。
# log—path 是保存当前环境的地方。
完全未修改的监视器,只能输出 平均reward 、训练时常 和 所利用时间。
当然不能满足我们的需求。最简单的办法 修改源代码。
在bench中 找到montior ,然后找到step和update函数 。
update 的输入中,中加入任何你要记录的东西,并将其加入之后的字典变量epinfo。,比如说:
def update(self, ob, rew, done, info, action): #58行
epinfo = {"ob": ob, "action": action, 're': rew ,'done': done , "t": round(time.time() - self.tstart, 6)}、
并更新在step中的调用update的时候的输入 。
之后,需要在112行 中fieldname中加入:
self.logger = csv.DictWriter(self.f, fieldnames=('ob', 'action', 're', 'done', 't')+tuple(extra_keys))
Tensorboard
其实训练日志,可以被输出为各种形式,其中有
'stdout' # 默认形式
'log' # txt
'json'
'csv'
'tensorboard' #
Tensorboard 的实用 可以直接配置成如下形式。
当然可以选择多种 输出形式,
logger.configure(dir=log_path,format_strs=['tensorboard'])
然后 在命令行中 , 先激活环境然后,再配置tensorboard 。
tensorboard --logdir log_path
之后打开网址http://localhost:6006
模型保存
baseline (DQN)会具有比较好的模型保存功能。每运行一段时间,如果存在当前得到的reward比较好的情况下就会保存当前模型。部分算法是没有的这个功能的,但是保存和调用变量的代码是存在common文件夹下的,可以直接Ctrl+C 、Ctrl+V。 直接调用就好了。可以不用写 variables tf 可以直接读取当前网络。
from baselines.common.tf_util import load_variables, save_variables
save_variables(save_path, variables=None, sess=None)
load_variables(load_path, variables=None, sess=None)