import ray
from ray import tune
from ray.rllib.agents.dqn import DQNTrainer
from ray.tune.schedulers import ASHAScheduler
import gym
# Initialize Ray with an explicit CPU budget (it must cover the rollout workers below plus the driver)
ray.init(num_cpus=8)
# Set up the training configuration
config = {
"log_level": "DEBUG",
"num_workers": 4,
"framework": "torch",
"lr": 2e-3,
"replay_buffer_config": {
"capacity": 10000,
"learning_starts": 500
},
"train_batch_size": 64,
"target_network_update_freq": 10,
"gamma": 0.99,
"env": "CartPole-v1",
}
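# Note: with a single fixed config there is only one trial, so the ASHAScheduler
# below has nothing to prune. A minimal sketch (an assumption, not part of the
# original) of a search space that would give ASHA several trials to compare:
search_config = dict(config, lr=tune.grid_search([5e-4, 1e-3, 2e-3]))
# To run the sweep, pass search_config instead of config in exp_config below;
# Tune then launches one trial per learning rate.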
exp_config = {
"run_or_experiment": DQNTrainer,
"checkpoint_freq": 250,
"checkpoint_at_end": True,
"local_dir": "model_path",
"stop": {"training_iteration": 1000},
"config": config,
"scheduler": ASHAScheduler(metric="episode_reward_mean", mode="max"),
}
# Run the Tune experiment and keep the returned ExperimentAnalysis object
analysis = tune.run(**exp_config)
# Load a trained model and evaluate it
def test_model(checkpoint_path, num_test_episodes=10):
    env = gym.make("CartPole-v1")
    # Rebuild a trainer with the training config (num_workers=0 would suffice
    # for evaluation) and restore the checkpointed weights
    trainer = DQNTrainer(config=config, env="CartPole-v1")
    trainer.restore(checkpoint_path)
total_reward = 0
for episode in range(num_test_episodes):
state = env.reset()
done = False
episode_reward = 0
        while not done:
            env.render()
            # Query the restored policy for an action; the classic gym
            # 4-tuple step API (obs, reward, done, info) is assumed here
            action = trainer.compute_single_action(state)
            next_state, reward, done, _ = env.step(action)
state = next_state
episode_reward += reward
total_reward += episode_reward
print(f"Episode {episode+1} reward: {episode_reward}")
avg_reward = total_reward / num_test_episodes
print(f"Average reward over {num_test_episodes} test episodes: {avg_reward}")
env.close()
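# A minimal sketch (assuming the analysis object returned by tune.run above) of
# recovering the best checkpoint instead of hard-coding a path; metric and mode
# mirror the scheduler settings. Depending on the Ray version this returns a
# path string or a Checkpoint object:
best_trial = analysis.get_best_trial("episode_reward_mean", "max")
best_checkpoint = analysis.get_best_checkpoint(
    best_trial, metric="episode_reward_mean", mode="max"
)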
# Assume we know the latest checkpoint path, or use best_checkpoint from the sketch above
checkpoint_path = "model_path"  # placeholder; point it at a real checkpoint
test_model(checkpoint_path)
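# Not in the original: explicitly release Ray's resources once the script is done
ray.shutdown()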