强化学习保存训练好的模型并测试(亲测有用!!!)

最近刚刚入坑强化学习,对于新手来说,真的好难!好不容易理解了强化学习算法的的训练过程,但是训练好总要用吧,怎么用呢?想在网上找到训练好的模型如何测试几乎找不到(哭哭),大佬们可能觉得这很简单吧,但是小的我是真不会。不过总不能还没开始就放弃吧,在不断尝试和搜索之后,finally,找到了解决方法。好了话不多说,我来记录一下强化学习到底是如何保存并测试模型的。

前言:听说对于新手来说,pytorch比较好入手,所以以下内容是基于torch的。

1、有一个训练好的模型,这个网上资料很多,如果想要快速熟悉强化学习整个算法流程的话,可以先download的一个试试看。

2、保存训练好的模型:(我用的是DQN训练走迷宫的例子)

torch.save(policy_net.state_dict(), 'dqn_maze_model.pkl')

3、测试模型

policy_net.load_state_dict(torch.load('dqn_maze_model.pkl')) #加载模型
policy_net.eval()  # 切换到评估模式

只需要将用到的网络结构的代码拷贝到测试代码中即可,别的不相关的东西不需要拷贝

贴心的附上测试代码:

import time
import torch
import numpy as np
from train_env import Train
from RL_brain import DQN

def main():
    env = Train()
    # 加载模型
    model = DQN(n_states=env.n_states, n_actions=env.n_actions)
    policy_net = model.eval_net  # 将DQN中的评估网络作为策略网络
    try:
        policy_net.load_state_dict(torch.load('dqn_maze_model.pkl'))
        policy_net.eval()  # 切换到评估模式
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    state = env.reset()
    print(f"Initial state: {state}")
    total_reward = 0
    done = False

    while not done:
        time.sleep(0.5)
        env.render()

        with torch.no_grad():
            state_np = np.array(state)
            state_tensor = torch.tensor(state_np, dtype=torch.float32)
            action = policy_net(state_tensor).argmax().item()
            print(f"Selected action: {action}")

        next_state, reward, done = env.step(action)
        print(f"Next state: {next_state}, Reward: {reward}, Done: {done}")

        state = next_state
        total_reward += reward

    print(f"Total reward: {total_reward}")
    env.destroy()

if __name__ == "__main__":
    main()

如果不知道torch.save()和torch.load()怎么用的,建议看这个保存提取 | 莫烦Python里面讲的很详细,真滴感谢!

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值