【强化学习笔记】(1) Q-learning, Sarsa

本文总结了强化学习中的几种基础算法,包括Q-learning、DQN、PG、AC、DDPG和TD3,介绍了它们的基本原理、Q-table的使用、epsilon-greedy策略以及Sarsa算法的差异。详细解释了如何通过Q-table更新和选择动作,以及在探索与利用之间的平衡策略。
摘要由CSDN通过智能技术生成

这一系列的整理为本人在学习强化学习算法过程中的一个笔记与总结,可能会较为浅显且部分有误,参考了其他人的经验总结。

强化学习部分基础算法总结(Q-learning DQN PG AC DDPG TD3) - 知乎 (zhihu.com)

1. Q-learning

1.1 Q-learning 基本原理

q-learning是一种off-policy算法,通过q-table来选择对应的action更新q-table中的q-value值。

可用于离散的和连续的 action space (离散和连续的概念,通俗来讲离散就是指的这个行为非0即1或者上下左右, 而连续则是类似于机器人控制任务)

伪代码如下所示:重点在选择action和更新q-table两步

贪婪算法是q-learning中非常重要的一个概念,贪婪算法是为了平衡探索和利用,即如果达到了贪婪的阈值,那么网络会从已有的q-table中选择最优的q-value所对应的动作;若没有达到贪婪的阈值,则会随机的从动作空间 (action space) 中进行采样。

epsilon-greedy:to balance exploration and expolitation [① random actions exploration ② highest Q-value expolitation]

1.2 代码说明

1.2.1 Q-table

Q-table: 在q-learning中最重要的就是q-table, 通过一个table得到q值和目标q

q-table格式如下,行代表state,列代表action,行列的值代表的是q-value

        self.q_table = pd.DataFrame(
            columns=self.actions,
            dtype=np.float

形象的画一个表如下:

A1A2A3...
S1Q11Q12Q13...
S2Q21Q22Q23...
S3Q31Q32Q33...
...............

每次遇到新的state的都会多出来一行,所有action对应的q值

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            self.q_table = self.q_table.append(
                pd.Series(
                    [0] * len(self.actions),
                    index=self.q_table.columns,
                    name=state
                )
            )
1.2.2 greedy

\mathbf{\epsilon}-greedy 在选择action的时候起作用,目的如前文所述,greedy threshold确定当前是更倾向探索还是最优值利用。简单来说,在训练前期重视探索,在训练后期则侧重于利用。

为什么?在初始状况下,在q-table中每一个state对应的action的q-value都是0,如果一直没有更新到的话就一直为零但是更新到了就会发生变化。

  def select_action(self, observation):
        self.check_state_exist(observation)
        if np.random.uniform() < self.epsilon:
            state_actions = self.q_table.loc[observation, :]
            action = np.random.choice(state_actions[state_actions == np.max(state_actions)].index)
        else:
            action = np.random.choice(self.actions)
        return action

注意,在小于贪婪值时,会在q-table中选择最优值,首先state_actions从q_table中选择所有给定observation对应的action的q值,action则是找到与这些q值中最大的那些值的索引,即action,并随机进行抽选。

1.2.3 更新 q-table

根据Bellman equation可以得到

Q(s, a) \leftarrow Q(s, a)+\alpha\left[r+\gamma \max _{a^{\prime}} Q\left(s^{\prime}, a^{\prime}\right)-Q(s, a)\right]

其中 r+\gamma \max _{a^{\prime}} Q\left(s^{\prime}, a^{\prime}\right) 为q_target,即目标的q值与当前奖励r_t还有未来的状态s_{t+1}最优动作对应的最大q-value有关。(可以想象为当前的这一步的q值和当前这一步的奖励,以及未来的评判有关(肯定是这一步的s和a执行了之后,得到的s_可以有很好的结果是最好的))

Q(s, a)q_predict, 预测的q值和当前这一时刻的动作a_t还有状态s_t有关,通过查表可得。

更新q-table,目的为了让q_predict越接近q_target越好

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'end':
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        else:
            q_target = r
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)

2. Sarsa

2.1 Sarsa 基本原理

Sarsa则是一种on-policy算法,更新Q-value的时候考虑的不再是最大值而是\left(s_{t+1}, a_{t+1} \right)对应的q值

In the SARSA algorithm, the Q-value is updated taking into account the action, A1, performed in the state, S1. In Q-learning, the action with the highest Q-value in the next state, S1, is used to update the Q-table.

 伪代码如下:

2.2 代码说明

Q(S, A) \leftarrow Q(S, A)+\alpha\left(R+\gamma Q\left(S^{\prime}, A^{\prime}\right)-Q(S, A)\right)

def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'end':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]
        else:
            q_target = r
        error = q_target - q_predict
        self.eligibility_trace.loc[s, :] *= 0
        self.eligibility_trace.loc[s, a] = 1
        self.q_table += self.lr * error * self.eligibility_trace
        self.eligibility_trace *= self.gamma * self.lambda_

3. 完整代码

import torch
import numpy as np
import pandas as pd


class QLearning:
    def __init__(self, actions, lr, reward_dacay, e_greedy):
        self.actions = actions
        self.lr = lr
        self.gamma = reward_dacay
        self.epsilon = e_greedy
        self.q_table = pd.DataFrame(
            columns=self.actions,
            dtype=np.float
        )

    def select_action(self, observation):
        self.check_state_exist(observation)
        if np.random.uniform() < self.epsilon:
            state_actions = self.q_table.loc[observation, :]
            action = np.random.choice(state_actions[state_actions == np.max(state_actions)].index)
        else:
            action = np.random.choice(self.actions)
        return action

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'end':
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        else:
            q_target = r
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            self.q_table = self.q_table.append(
                pd.Series(
                    [0] * len(self.actions),
                    index=self.q_table.columns,
                    name=state
                )
            )


class Sarsa:
    """
    Sarsa Lambda
    """
    def __init__(self, actions, learning_rate, reward_decay, e_greedy, sarsa_lambda):
        self.actions = actions
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy
        self.q_table = pd.DataFrame(
            columns=self.actions,
            dtype=np.float
        )
        self.lambda_ = sarsa_lambda
        self.eligibility_trace = self.q_table.copy()

    def select_action(self, observation):
        self.check_state_exist(observation)
        if np.random.uniform() < self.epsilon:
            state_actions = self.q_table.loc[observation, :]
            action = np.random.choice(state_actions[state_actions == np.max(state_actions)].index)
        else:
            action = np.random.choice(self.actions)
        return action

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'end':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]
        else:
            q_target = r
        error = q_target - q_predict
        self.eligibility_trace.loc[s, :] *= 0
        self.eligibility_trace.loc[s, a] = 1
        self.q_table += self.lr * error * self.eligibility_trace
        self.eligibility_trace *= self.gamma * self.lambda_

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            to_be_append = pd.Series(
                    [0] * len(self.actions),
                    index=self.q_table.columns,
                    name=state
                )
            self.q_table = self.q_table.append(to_be_append)
            self.eligibility_trace = self.eligibility_trace.append(to_be_append)

  • 22
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值