An Implementation Framework for SARSA and Q-Learning

SARSA and Q-Learning are two common reinforcement learning algorithms; their pseudocode and detailed algorithmic flow are not described here. This article focuses instead on a general implementation framework, taking Gym's Taxi-v3 environment as the example and implementing it in Python with NumPy.

The libraries to import are as follows:

import gym
import numpy as np
import matplotlib.pyplot as plt

1. Decoding the taxi state

The first modeling step for the Taxi-v3 problem is to describe the set of game state vectors, $X=\{(\text{taxi row},\ \text{taxi column},\ \text{passenger location},\ \text{destination})^T : \text{taxi row } 0\sim 4,\ \text{taxi column } 0\sim 4,\ \text{passenger location } 0\sim 4,\ \text{destination } 0\sim 3\}$, the integer state set $S=\{s: 0,1,2,\dots,499\}$, and the one-to-one correspondence between the two sets ($5\times 5\times 5\times 4 = 500$ states). The decoding code is as follows:

def examples():
    env = gym.make("Taxi-v3")
    state = env.reset()
    taxirow,taxicol,passloc,destidx = env.unwrapped.decode(state) # this decoding step is essential
    print(taxirow,taxicol,passloc,destidx)
    print("Taxi position: {}".format((taxirow,taxicol)))
    print("Passenger location: {}".format(env.unwrapped.locs[passloc]))
    print("Destination: {}".format(env.unwrapped.locs[destidx]))
    env.render()
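
To confirm the one-to-one correspondence between $X$ and $S$ in code, the decoded tuple can be mapped back to the integer state. This is only a sketch; it assumes the underlying Taxi environment also exposes an `encode` method that inverts `decode`, as in the classic Gym implementation:

# Sanity check of the state <-> tuple bijection (sketch; assumes env.unwrapped.encode exists)
def check_state_encoding():
    env = gym.make("Taxi-v3")
    state = env.reset()
    taxirow, taxicol, passloc, destidx = env.unwrapped.decode(state)
    # encode should invert decode and recover the original integer state
    recovered = env.unwrapped.encode(taxirow, taxicol, passloc, destidx)
    assert recovered == state
    print("state {} <-> tuple {}".format(state, (taxirow, taxicol, passloc, destidx)))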

2. Setting up the agent classes

This framework wraps the agent in a single class: the agent's decision (action selection) and learning procedures become methods of that class, and the various hyperparameters together with the action-value table $Q_\pi(s,a)$ become its attributes. Action selection uses an $\varepsilon$-greedy policy to obtain $\hat{a}=\pi(s)$. For the learning procedure, only the following three cases are listed (a minimal single-step usage check follows the three classes below):

  • For the SARSA algorithm, the update of $Q_\pi(S_t,A_t)$ is as follows (given $(s_t,a_t,r_t,s_{t+1},a_{t+1})$):

$$q(s_t,a_t) \leftarrow q(s_t,a_t) + \alpha\big(r_t + \gamma\, q(s_{t+1},a_{t+1}) - q(s_t,a_t)\big)$$

# SARSA agent class
class SARSAAgent:
    def __init__(self,env,gamma=0.9,learning_rate=0.1,epsilon=0.01):
        self.gamma = gamma
        self.learning_rate = learning_rate
        self.epsilon = epsilon
        self.action_n = env.action_space.n
        self.q = np.zeros((env.observation_space.n,env.action_space.n))

    # decision: epsilon-greedy action selection
    def decide(self,state):
        if np.random.uniform() > self.epsilon:
            action = self.q[state].argmax()
        else:
            action = np.random.randint(self.action_n)
        return action

    # learning: on-policy SARSA update using the next action actually taken
    def learn(self,state,action,reward,next_state,done,next_action):
        u = reward + self.gamma*self.q[next_state,next_action]*(1 - done)
        td_error = u - self.q[state,action]
        self.q[state,action] += self.learning_rate*td_error
  • For the expected SARSA algorithm, the update is as follows (given $(s_t,a_t,r_t,s_{t+1})$):

$$q(s_t,a_t) \leftarrow q(s_t,a_t) + \alpha\Big(r_t + \gamma \sum_{a\in\mathcal{A}(s_{t+1})} \pi(a\mid s_{t+1})\, q(s_{t+1},a) - q(s_t,a_t)\Big)$$


# Expected SARSA agent class
class ExpectedSARSAAgent:
    def __init__(self,env,gamma=0.9,learning_rate=0.1,epsilon=0.01):
        self.gamma = gamma
        self.learning_rate = learning_rate
        self.epsilon = epsilon
        self.action_n = env.action_space.n
        self.q = np.zeros((env.observation_space.n,env.action_space.n))

    # decision: epsilon-greedy action selection
    def decide(self,state):
        if np.random.uniform() > self.epsilon:
            action = self.q[state].argmax()
        else:
            action = np.random.randint(self.action_n)
        return action

    # learning: expected SARSA update
    def learn(self,state,action,reward,next_state,done):
        # expectation of q under the epsilon-greedy policy:
        # (epsilon/|A|) * sum_a q(s',a) + (1 - epsilon) * max_a q(s',a)
        v = (self.q[next_state].mean()*self.epsilon +
             self.q[next_state].max()*(1 - self.epsilon))
        u = reward + self.gamma*v*(1 - done)
        td_error = u - self.q[state,action]
        self.q[state,action] += self.learning_rate*td_error
  • For the Q-Learning algorithm, the update is as follows (given $(s_t,a_t,r_t,s_{t+1})$):

$$q(s_t,a_t) \leftarrow q(s_t,a_t) + \alpha\Big(r_t + \gamma \max_{a\in\mathcal{A}(s_{t+1})} q(s_{t+1},a) - q(s_t,a_t)\Big)$$

# Q-Learning agent class
class QLearningAgent:
    def __init__(self,env,gamma=0.9,learning_rate=0.1,epsilon=0.01):
        self.gamma = gamma
        self.learning_rate = learning_rate
        self.epsilon = epsilon
        self.action_n = env.action_space.n
        self.q = np.zeros((env.observation_space.n,env.action_space.n))

    # decision: epsilon-greedy action selection
    def decide(self,state):
        if np.random.uniform() > self.epsilon:
            action = self.q[state].argmax()
        else:
            action = np.random.randint(self.action_n)
        return action

    # learning: off-policy Q-learning update using the greedy bootstrap value
    def learn(self,state,action,reward,next_state,done):
        u = reward + self.gamma*self.q[next_state].max()*(1 - done)
        td_error = u - self.q[state,action]
        self.q[state,action] += self.learning_rate*td_error
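
As a quick check that the three classes share the same interface apart from the learn() signature, an agent can be driven by hand on one transition. This is only a sketch, not part of the training pipeline, and the transition values below are fabricated purely for illustration:

# Single-step sanity check (sketch; the transition values are made up)
def single_step_check():
    env = gym.make("Taxi-v3")
    agent = QLearningAgent(env)
    before = agent.q[0, 1]
    agent.learn(state=0, action=1, reward=-1, next_state=5, done=False)
    after = agent.q[0, 1]
    # the Q-value should move from 0 toward the TD target r + gamma*max_a q(s',a) = -1
    print("q(0,1): {} -> {}".format(before, after))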

3. Setting up the agent-environment interaction

Taking the SARSA algorithm as the example; the other algorithms can be obtained by modifying this loop (a variant for Q-Learning and expected SARSA is sketched after the code below).

# One episode of agent-environment interaction (SARSA)
def play_sarsa(env,agent,train=False,render=False):
    episode_reward = 0
    observation = env.reset()
    action = agent.decide(observation)
    while True:
        if render:
            env.render()
        next_observation,reward,done,_ = env.step(action)
        episode_reward += reward
        next_action = agent.decide(next_observation)
        if train:
            agent.learn(observation,action,reward,
                        next_observation,done,next_action)
        if done:
            break
        observation,action = next_observation,next_action
    return episode_reward
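
For QLearningAgent and ExpectedSARSAAgent, learn() no longer needs the next action, so the interaction loop simplifies. A possible variant, sketched here under the same pre-0.26 Gym reset/step API used throughout this article, could look like this:

# One episode of interaction for agents whose learn() takes (state, action, reward, next_state, done),
# i.e. QLearningAgent and ExpectedSARSAAgent
def play_qlearning(env, agent, train=False, render=False):
    episode_reward = 0
    observation = env.reset()
    while True:
        if render:
            env.render()
        action = agent.decide(observation)
        next_observation, reward, done, _ = env.step(action)
        episode_reward += reward
        if train:
            agent.learn(observation, action, reward, next_observation, done)
        if done:
            break
        observation = next_observation
    return episode_reward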

4. Main function

This covers training and learning within each episode, followed by testing with the trained Q-table (setting $\varepsilon$ to 0 makes the policy purely greedy, which is the test mode).

Taking the SARSA algorithm as the example:

if __name__=="__main__":
    np.random.seed(0)
    env = gym.make("Taxi-v3")
    state = env.reset()
    agent = SARSAAgent(env)
    # training
    episodes = 5000
    episode_rewards = []
    for episode in range(episodes):
        episode_reward = play_sarsa(env,agent,train=True)
        episode_rewards.append(episode_reward)
    plt.plot(episode_rewards)
    # testing: switch off exploration
    agent.epsilon = 0
    episode_rewards = []
    for i in range(100):
        episode_reward = play_sarsa(env,agent)
        episode_rewards.append(episode_reward)
        print("Test episode {} reward: {}".format(i,episode_reward))

    print("Average reward: {}".format(np.mean(episode_rewards)))
    print("q={}".format(agent.q))
    plt.waitforbuttonpress()
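
To train one of the other two agents, only the agent construction and the interaction function change. The snippet below is a sketch rather than part of the original main function; it assumes `env` is the Taxi-v3 environment created above and that `play_qlearning` is defined as in the sketch of Section 3:

# Hypothetical swap: train and evaluate a Q-learning agent instead of SARSA
agent = QLearningAgent(env)
episode_rewards = [play_qlearning(env, agent, train=True) for _ in range(5000)]
plt.plot(episode_rewards)

agent.epsilon = 0  # purely greedy policy for evaluation
test_rewards = [play_qlearning(env, agent) for _ in range(100)]
print("Average test reward: {}".format(np.mean(test_rewards)))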

5. Results

(Figure: per-episode training reward curve.)

6. Appendix: complete code

import gym
import numpy as np
import matplotlib.pyplot as plt

# SARSA agent class
class SARSAAgent:
    def __init__(self,env,gamma=0.9,learning_rate=0.1,epsilon=0.01):
        self.gamma = gamma
        self.learning_rate = learning_rate
        self.epsilon = epsilon
        self.action_n = env.action_space.n
        self.q = np.zeros((env.observation_space.n,env.action_space.n))

    # decision: epsilon-greedy action selection
    def decide(self,state):
        if np.random.uniform() > self.epsilon:
            action = self.q[state].argmax()
        else:
            action = np.random.randint(self.action_n)
        return action

    # learning: on-policy SARSA update using the next action actually taken
    def learn(self,state,action,reward,next_state,done,next_action):
        u = reward + self.gamma*self.q[next_state,next_action]*(1 - done)
        td_error = u - self.q[state,action]
        self.q[state,action] += self.learning_rate*td_error

# One episode of agent-environment interaction (SARSA)
def play_sarsa(env,agent,train=False,render=False):
    episode_reward = 0
    observation = env.reset()
    action = agent.decide(observation)
    while True:
        if render:
            env.render()
        next_observation,reward,done,_ = env.step(action)
        episode_reward += reward
        next_action = agent.decide(next_observation)
        if train:
            agent.learn(observation,action,reward,
                        next_observation,done,next_action)
        if done:
            break
        observation,action = next_observation,next_action
    return episode_reward

def examples():
    env = gym.make("Taxi-v3")
    state = env.reset()
    taxirow,taxicol,passloc,destidx = env.unwrapped.decode(state) # this decoding step is essential
    print(taxirow,taxicol,passloc,destidx)
    print("Taxi position: {}".format((taxirow,taxicol)))
    print("Passenger location: {}".format(env.unwrapped.locs[passloc]))
    print("Destination: {}".format(env.unwrapped.locs[destidx]))
    env.render()

if __name__=="__main__":
    np.random.seed(0)
    env = gym.make("Taxi-v3")
    state = env.reset()
    agent = SARSAAgent(env)
    # training
    episodes = 5000
    episode_rewards = []
    for episode in range(episodes):
        episode_reward = play_sarsa(env,agent,train=True)
        episode_rewards.append(episode_reward)
    plt.plot(episode_rewards)
    # testing: switch off exploration
    agent.epsilon = 0
    episode_rewards = []
    for i in range(100):
        episode_reward = play_sarsa(env,agent)
        episode_rewards.append(episode_reward)
        print("Test episode {} reward: {}".format(i,episode_reward))

    print("Average reward: {}".format(np.mean(episode_rewards)))
    print("q={}".format(agent.q))
    plt.waitforbuttonpress()
