【STABLE_BASELINE3】强化学习算法实战:模型的保存,读取,再训练

问题由来

有时候需要在已经训练好的模型基础上进行再次训练

模型基础

链接: table-baselines3手册

使用代码

简单的保存和读取

import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy


# Create environment
env = gym.make('LunarLander-v2')

# Instantiate the agent
model = DQN('MlpPolicy', env, verbose=1)
# Train the agent
model.learn(total_timesteps=int(2e5))
# Save the agent
model.save("dqn_lunar")
del model  # delete trained model to demonstrate loading

# Load the trained agent
# NOTE: if you have loading issue, you can pass `print_system_info=True`
# to compare the system on which the model was trained vs the current one
# model = DQN.load("dqn_lunar", env=env, print_system_info=True)
model = DQN.load("dqn_lunar", env=env)

# Evaluate the agent
# NOTE: If you use wrappers with your environment that modify rewards,
#       this will be reflected here. To evaluate with original rewards,
#       wrap environment in a "Monitor" wrapper before other wrappers.
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)

# Enjoy trained agent
obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, info = env.step(action)
    env.render()

复杂的读取

先保存模型

model.save("best_model")
model.save_replay_buffer("best_model_buffer")

如果已经保存好了模型,然后使用的是自定义环境

# 自定义的环境
env_ca = Env('ddpg_ca', [CA(ca())])
# 读取DDPG模型
model_ca = DDPG.load('best_model.zip')
# 读取DDPG模型的buffer
model_ca.load_replay_buffer("best_model_buffer")
# 保存模型是包含了环境的,但是自定义环境可能出问题,所以重新设置环境
model_ca.set_env(env_ca, force_reset=True)
# 然后就可以再次训练了
checkpoint_callback = CheckpointCallback(save_freq=100, save_path="./logs_re/" + env_ca.env_name + '/',
    name_prefix="rl_model", save_replay_buffer=False, save_vecnormalize=False,)
model_ca.learn(total_timesteps=20000, log_interval=10, callback=checkpoint_callback)

训练结果

第一次训练
这是第一次训练的图
再次训练
这是再次读取模型之后训练的图

  • 1
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值