最近刚刚入坑强化学习,对于新手来说,真的好难!好不容易理解了强化学习算法的的训练过程,但是训练好总要用吧,怎么用呢?想在网上找到训练好的模型如何测试几乎找不到(哭哭),大佬们可能觉得这很简单吧,但是小的我是真不会。不过总不能还没开始就放弃吧,在不断尝试和搜索之后,finally,找到了解决方法。好了话不多说,我来记录一下强化学习到底是如何保存并测试模型的。
前言:听说对于新手来说,pytorch比较好入手,所以以下内容是基于torch的。
1、有一个训练好的模型,这个网上资料很多,如果想要快速熟悉强化学习整个算法流程的话,可以先download的一个试试看。
2、保存训练好的模型:(我用的是DQN训练走迷宫的例子)
torch.save(policy_net.state_dict(), 'dqn_maze_model.pkl')
3、测试模型
policy_net.load_state_dict(torch.load('dqn_maze_model.pkl')) #加载模型
policy_net.eval() # 切换到评估模式
只需要将用到的网络结构的代码拷贝到测试代码中即可,别的不相关的东西不需要拷贝
贴心的附上测试代码:
import time
import torch
import numpy as np
from train_env import Train
from RL_brain import DQN
def main():
env = Train()
# 加载模型
model = DQN(n_states=env.n_states, n_actions=env.n_actions)
policy_net = model.eval_net # 将DQN中的评估网络作为策略网络
try:
policy_net.load_state_dict(torch.load('dqn_maze_model.pkl'))
policy_net.eval() # 切换到评估模式
print("Model loaded successfully.")
except Exception as e:
print(f"Error loading model: {e}")
return
state = env.reset()
print(f"Initial state: {state}")
total_reward = 0
done = False
while not done:
time.sleep(0.5)
env.render()
with torch.no_grad():
state_np = np.array(state)
state_tensor = torch.tensor(state_np, dtype=torch.float32)
action = policy_net(state_tensor).argmax().item()
print(f"Selected action: {action}")
next_state, reward, done = env.step(action)
print(f"Next state: {next_state}, Reward: {reward}, Done: {done}")
state = next_state
total_reward += reward
print(f"Total reward: {total_reward}")
env.destroy()
if __name__ == "__main__":
main()
如果不知道torch.save()和torch.load()怎么用的,建议看这个保存提取 | 莫烦Python里面讲的很详细,真滴感谢!