SARSA-Based Taxi Dispatch 🤔
More code: Gitee homepage: https://gitee.com/GZHzzz
Blog homepage: CSDN: https://blog.csdn.net/gzhzzaa
Before we start
- As a newcomer, I'm writing this reinforcement-learning basics column to share my own RL learning journey; I hope we can all exchange ideas and improve together! 😁 On my Gitee I've collected classic reinforcement-learning papers and built typical PyTorch-based agent models. Come read, discuss, and learn alongside me! 😊
Taxi dispatch
- Gym's Taxi-v3 environment implements the taxi dispatch problem. After importing the environment, use env.reset() to initialize it, env.step() to take one step, and env.render() to display the current situation. env.render() prints a map in which the passenger's location and the destination are marked with colored letters, and the taxi's position is highlighted. Specifically, if the passenger is not in the taxi, the letter at the passenger's waiting location is shown in blue, and the destination letter is shown in magenta. While the passenger is not in the taxi, the taxi's cell is highlighted in yellow; once the passenger is on board, it is highlighted in green.
show me code, no bb
import numpy as np
np.random.seed(0)
import pandas as pd
import matplotlib.pyplot as plt
import gym
# Using the environment
env = gym.make('Taxi-v3')
env.seed(0)
print('Observation space = {}'.format(env.observation_space))
print('Action space = {}'.format(env.action_space))
print('Number of states = {}'.format(env.observation_space.n))
print('Number of actions = {}'.format(env.action_space.n))
state = env.reset()
taxirow, taxicol, passloc, destidx = env.unwrapped.decode(state)
print(taxirow, taxicol, passloc, destidx)
print('Taxi position = {}'.format((taxirow, taxicol)))
print('Passenger location = {}'.format(env.unwrapped.locs[passloc]))
print('Destination = {}'.format(env.unwrapped.locs[destidx]))
env.render()
env.step(0)  # action 0 moves the taxi south
env.render()
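# The 500 states factor as 25 taxi cells x 5 passenger locations x 4 destinations,
# and encode() is the inverse of decode() (a quick sanity check, assuming the
# standard Taxi-v3 helpers):
assert env.unwrapped.encode(taxirow, taxicol, passloc, destidx) == state
# The six actions are: 0 = south, 1 = north, 2 = east, 3 = west,
# 4 = pick up the passenger, 5 = drop off the passenger.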
# SARSA
class SARSAAgent:
    def __init__(self, env, gamma=0.9, learning_rate=0.2, epsilon=.01):
        self.gamma = gamma
        self.learning_rate = learning_rate
        self.epsilon = epsilon
        self.action_n = env.action_space.n
        self.q = np.zeros((env.observation_space.n, env.action_space.n))

    def decide(self, state):
        # epsilon-greedy: exploit with probability 1 - epsilon, otherwise explore
        if np.random.uniform() > self.epsilon:
            action = self.q[state].argmax()
        else:
            action = np.random.randint(self.action_n)
        return action

    def learn(self, state, action, reward, next_state, done, next_action):
        # bootstrap from the action actually taken at next_state
        u = reward + self.gamma * \
                self.q[next_state, next_action] * (1. - done)
        td_error = u - self.q[state, action]
        self.q[state, action] += self.learning_rate * td_error
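# The update in learn() is the tabular SARSA rule:
#     Q(s, a) <- Q(s, a) + alpha * [r + gamma * (1 - done) * Q(s', a') - Q(s, a)]
# Since a' is the action the current policy actually takes at s' (on-policy),
# this differs from Q-learning, which bootstraps from max_a Q(s', a) instead.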
def play_sarsa(env, agent, train=False, render=False):
    episode_reward = 0
    observation = env.reset()
    action = agent.decide(observation)
    while True:
        if render:
            env.render()
        next_observation, reward, done, _ = env.step(action)
        episode_reward += reward
        next_action = agent.decide(next_observation)  # meaningless at a terminal state
        if train:
            agent.learn(observation, action, reward, next_observation,
                    done, next_action)
        if done:
            break
        observation, action = next_observation, next_action
    return episode_reward
agent = SARSAAgent(env)

# Training
episodes = 3000
episode_rewards = []
for episode in range(episodes):
    episode_reward = play_sarsa(env, agent, train=True)
    episode_rewards.append(episode_reward)
plt.plot(episode_rewards)
plt.show()

# Testing
agent.epsilon = 0.  # disable exploration
episode_rewards = [play_sarsa(env, agent) for _ in range(100)]
print('Average episode reward = {} / {} = {}'.format(sum(episode_rewards),
        len(episode_rewards), np.mean(episode_rewards)))

# Show the estimated optimal action values
print(pd.DataFrame(agent.q))

# Show the estimated optimal (greedy) policy, one-hot per state
policy = np.eye(agent.action_n)[agent.q.argmax(axis=-1)]
print(pd.DataFrame(policy))
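- To watch the learned policy drive, a minimal extra step is to reuse play_sarsa with rendering turned on (exploration is already disabled at this point):

play_sarsa(env, agent, render=True)  # render one greedy episode in the terminal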
- I've personally run all of this code, you know! 😝
Results
Final words
Ten years sharpening one sword; let's keep striving together!
More code: Gitee homepage: https://gitee.com/GZHzzz
Blog homepage: CSDN: https://blog.csdn.net/gzhzzaa
- Fighting!😎
Classic PyTorch-based models: typical PyTorch-based agent models
Classic reinforcement-learning papers: classic RL papers
while True:
    Go life