Problem Background
Sometimes you need to continue training on top of a model that has already been trained.
Model Basics
Code Usage
Simple saving and loading
import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
# Create environment
env = gym.make('LunarLander-v2')
# Instantiate the agent
model = DQN('MlpPolicy', env, verbose=1)
# Train the agent
model.learn(total_timesteps=int(2e5))
# Save the agent
model.save("dqn_lunar")
del model # delete trained model to demonstrate loading
# Load the trained agent
# NOTE: if you have loading issue, you can pass `print_system_info=True`
# to compare the system on which the model was trained vs the current one
# model = DQN.load("dqn_lunar", env=env, print_system_info=True)
model = DQN.load("dqn_lunar", env=env)
# Evaluate the agent
# NOTE: If you use wrappers with your environment that modify rewards,
# this will be reflected here. To evaluate with original rewards,
# wrap environment in a "Monitor" wrapper before other wrappers.
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
# Enjoy trained agent
obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, info = env.step(action)
    env.render()
    if dones:
        obs = env.reset()  # LunarLander episodes end well before 1000 steps
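As the note above says, wrappers that modify rewards also change what evaluate_policy reports. A minimal sketch of wrapping the environment in Monitor first, so that evaluation sees the original episode rewards:
import gym
from stable_baselines3.common.monitor import Monitor

# Monitor must be the innermost wrapper so it records the raw rewards;
# any reward-modifying wrappers would go outside it
env = Monitor(gym.make('LunarLander-v2'))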
More complex loading
First, save the model:
model.save("best_model")  # saves policy weights and hyperparameters
model.save_replay_buffer("best_model_buffer")  # the replay buffer must be saved separately
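model.save() does not include the replay buffer, so off-policy algorithms (DQN, DDPG, SAC, TD3) lose their collected experience unless it is saved separately. A quick sanity check after loading, as a sketch:
from stable_baselines3 import DDPG

loaded = DDPG.load("best_model.zip")
loaded.load_replay_buffer("best_model_buffer")
# size() returns the number of transitions currently stored
print(f"Replay buffer holds {loaded.replay_buffer.size()} transitions")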
If the model has already been saved and you are working with a custom environment:
from stable_baselines3 import DDPG
from stable_baselines3.common.callbacks import CheckpointCallback

# Custom environment (Env, CA, and ca are user-defined classes)
env_ca = Env('ddpg_ca', [CA(ca())])
# Load the DDPG model
model_ca = DDPG.load('best_model.zip')
# Load the DDPG model's replay buffer
model_ca.load_replay_buffer("best_model_buffer")
# The saved model carries environment info, but a custom environment may not
# restore cleanly, so set the environment again explicitly
model_ca.set_env(env_ca, force_reset=True)
# Training can now be resumed
checkpoint_callback = CheckpointCallback(
    save_freq=100,
    save_path="./logs_re/" + env_ca.env_name + '/',
    name_prefix="rl_model",
    save_replay_buffer=False,
    save_vecnormalize=False,
)
model_ca.learn(total_timesteps=20000, log_interval=10, callback=checkpoint_callback)
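By default, learn() restarts the internal timestep counter on every call. If logging and checkpoint numbering should continue from where the first run ended, Stable Baselines3's reset_num_timesteps flag can be set to False; a sketch:
# Continue the timestep count from the previous run instead of starting at 0
model_ca.learn(
    total_timesteps=20000,
    log_interval=10,
    callback=checkpoint_callback,
    reset_num_timesteps=False,
)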
Training results
First training run
Retraining