Ray Reinforcement Learning Framework Usage Example (RLlib + Tune)

import ray
from ray import tune
from ray.rllib.agents.dqn import DQNTrainer   # legacy RLlib trainer API (moved to ray.rllib.algorithms in Ray 2.x)
from ray.tune.schedulers import ASHAScheduler
import gym

# Initialize Ray with a local CPU budget
ray.init(num_cpus=8)
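
# (Sketch) To attach to an already running Ray cluster instead of starting a
# local instance, ray.init(address="auto") could be used; the address value
# is an assumption about your deployment.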
 
# Set up the training configuration
config = {
    "log_level": "DEBUG",
    "num_workers": 4,                       # parallel rollout workers
    "framework": "torch",                   # use the PyTorch backend
    "lr": 2e-3,                             # learning rate
    "replay_buffer_config": {
        "capacity": 10000,                  # replay buffer size (in timesteps)
        "learning_starts": 500              # timesteps sampled before learning begins
    },
    "train_batch_size": 64,
    "target_network_update_freq": 10,       # timesteps between target-network syncs
    "gamma": 0.99,                          # discount factor
    "env": "CartPole-v1",
}
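
# (Sketch, assuming Ray >= 2.0) Roughly the same settings with the newer
# config-object API, where the trainer class is DQN rather than DQNTrainer:
# from ray.rllib.algorithms.dqn import DQNConfig
# dqn_config = (
#     DQNConfig()
#     .environment("CartPole-v1")
#     .framework("torch")
#     .rollouts(num_rollout_workers=4)
#     .training(lr=2e-3, gamma=0.99, train_batch_size=64)
# )
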
exp_config = {
    "run_or_experiment": DQNTrainer,
    "checkpoint_freq": 250,                 # save a checkpoint every 250 training iterations
    "checkpoint_at_end": True,
    "local_dir": "model_path",              # directory where Tune writes results and checkpoints
    "stop": {"training_iteration": 1000},
    "config": config,
    "scheduler": ASHAScheduler(metric="episode_reward_mean", mode="max"),
}

# Run the Tune experiment to train; keep the returned analysis object so the
# best checkpoint can be looked up afterwards
analysis = tune.run(**exp_config)
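
# (Sketch) ASHAScheduler only has an effect when several trials compete, so in
# practice "lr" (or another key) would usually be turned into a search space
# and multiple samples launched; the values below are assumptions:
# search_config = dict(config)
# search_config["lr"] = tune.loguniform(1e-4, 1e-2)
# tune.run(DQNTrainer, config=search_config, num_samples=8,
#          stop={"training_iteration": 1000},
#          scheduler=ASHAScheduler(metric="episode_reward_mean", mode="max"))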


  
# Load the trained model and evaluate it
def test_model(checkpoint_path, num_test_episodes=10):
    env = gym.make("CartPole-v1")   # classic gym API (reset() -> obs, step() -> 4-tuple)
    trainer = DQNTrainer(config=config, env="CartPole-v1")
    trainer.restore(checkpoint_path)

    total_reward = 0
    for episode in range(num_test_episodes):
        state = env.reset()
        done = False
        episode_reward = 0
        while not done:
            env.render()
            # Act greedily (no exploration) when evaluating the trained policy
            action = trainer.compute_single_action(state, explore=False)
            next_state, reward, done, _ = env.step(action)
            state = next_state
            episode_reward += reward
        total_reward += episode_reward
        print(f"Episode {episode+1} reward: {episode_reward}")

    avg_reward = total_reward / num_test_episodes
    print(f"Average reward over {num_test_episodes} test episodes: {avg_reward}")

    env.close()
# Look up the best checkpoint from the Tune results (alternatively, pass a
# known checkpoint path under "model_path" directly)
best_trial = analysis.get_best_trial(metric="episode_reward_mean", mode="max")
checkpoint_path = analysis.get_best_checkpoint(best_trial, metric="episode_reward_mean", mode="max")
test_model(checkpoint_path)
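
# Release the resources reserved by ray.init() once evaluation is finished
ray.shutdown()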
