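This series is my own set of notes and summaries from studying reinforcement learning algorithms. It may be fairly basic and contain some mistakes, and it draws on summaries written by other people.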
Summary of some basic reinforcement learning algorithms (Q-learning, DQN, PG, AC, DDPG, TD3) - 知乎 (zhihu.com)
1. Q-learning
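1.1 Q-learning basic principles
Q-learning is an off-policy algorithm: it uses a Q-table to select an action and then updates the Q-values stored in that table.
In its tabular form it applies to discrete action spaces; a continuous action space has to be discretized first or handled with function approximation. (Informally, discrete means the action is one of a finite set, such as 0/1 or up/down/left/right, while continuous actions are the kind found in robot-control tasks.)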
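In the pseudocode, the two key steps are selecting an action and updating the Q-table.
The ε-greedy strategy is a central idea in Q-learning; its purpose is to balance exploration and exploitation. A random number is compared against the greedy threshold: if it falls below the threshold, the agent picks the action with the highest Q-value in the existing Q-table; otherwise it samples an action at random from the action space. Note that in this post's code the threshold is the probability of exploiting, which is the reverse of the common convention where ε is the probability of exploring.
epsilon-greedy: to balance exploration and exploitation [① random actions: exploration ② highest Q-value: exploitation]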
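The two key steps, action selection and the table update, can be sketched on a toy problem. This is a minimal illustration and not the code from this post: the 1-D corridor environment, the function name train_corridor, and all default hyperparameters are assumptions made for the example. As in the rest of the post, epsilon here is the probability of acting greedily.

```python
import numpy as np

def train_corridor(n_states=5, episodes=200, lr=0.1, gamma=0.9,
                   epsilon=0.9, seed=0):
    """Tabular Q-learning on a hypothetical 1-D corridor.

    States are 0..n_states-1; action 0 moves left, action 1 moves right.
    Reaching the right-most state gives reward 1 and ends the episode.
    """
    rng = np.random.default_rng(seed)
    q = np.zeros((n_states, 2))  # rows: states, columns: actions
    for _ in range(episodes):
        s = 0
        while s != n_states - 1:
            # Step 1: select an action (epsilon = probability of exploiting).
            if rng.random() < epsilon:
                best = np.flatnonzero(q[s] == q[s].max())
                a = int(rng.choice(best))  # break ties at random
            else:
                a = int(rng.integers(2))   # explore: uniform random action
            # Environment transition (assumed toy dynamics).
            s_next = max(s - 1, 0) if a == 0 else s + 1
            done = s_next == n_states - 1
            r = 1.0 if done else 0.0
            # Step 2: move Q(s, a) toward the bootstrapped target.
            target = r if done else r + gamma * q[s_next].max()
            q[s, a] += lr * (target - q[s, a])
            s = s_next
    return q

q = train_corridor()
print(q.argmax(axis=1))  # greedy action per state
```

Because the only reward sits at the right end, the greedy action learned for every non-terminal state should be "right"; the terminal row is never updated and stays at zero.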
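1.2 Code walkthrough
1.2.1 Q-table
Q-table: the Q-table is the most important part of Q-learning; both the predicted Q-value and the target Q-value are read from this single table.
The Q-table is laid out as follows: rows are states, columns are actions, and each cell holds a Q-value.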
self.q_table = pd.DataFrame(
    columns=self.actions,
    dtype=np.float64  # np.float was removed in NumPy 1.24
)
Drawn as a table, it looks like this:
| | A1 | A2 | A3 | ... |
|---|---|---|---|---|
| S1 | Q11 | Q12 | Q13 | ... |
| S2 | Q21 | Q22 | Q23 | ... |
| S3 | Q31 | Q32 | Q33 | ... |
| ... | ... | ... | ... | ... |
Each time a new state is encountered, one more row is added, holding the Q-values of all actions for that state (initialized to 0).
def check_state_exist(self, state):
    # Add an all-zero row the first time a state is seen.
    # DataFrame.append was removed in pandas 2.0, so enlarge via .loc instead.
    if state not in self.q_table.index:
        self.q_table.loc[state] = [0.0] * len(self.actions)
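To see the table grow row by row, here is a small standalone sketch. The action labels, the state names, and the helper name ensure_state are made up for illustration; it assumes states are hashable labels such as strings.

```python
import numpy as np
import pandas as pd

actions = ["up", "down", "left", "right"]  # hypothetical action labels
q_table = pd.DataFrame(columns=actions, dtype=np.float64)

def ensure_state(table, state):
    """Add an all-zero row for `state` if the table has not seen it yet."""
    if state not in table.index:
        table.loc[state] = [0.0] * len(table.columns)
    return table

for s in ["s1", "s2", "s1"]:  # "s1" is visited twice but added only once
    ensure_state(q_table, s)

print(q_table)
```

The table ends up with one row per distinct state, so its size tracks the states actually visited and not the full state space.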
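1.2.2 ε-greedy
ε-greedy takes effect when an action is selected. As described earlier, the greedy threshold determines whether the agent currently leans toward exploration or toward exploiting the best known value. Simply put, exploration matters more early in training and exploitation matters more later on.
Why? Initially the Q-value of every action in every state is 0. An entry that is never updated stays at zero and only changes once it has been visited and updated, so early on the table gives the agent nothing useful to exploit.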
def select_action(self, observation):
    self.check_state_exist(observation)
    if np.random.uniform() < self.epsilon:
        # Exploit: pick among the actions with the highest Q-value
        state_actions = self.q_table.loc[observation, :]
        action = np.random.choice(state_actions[state_actions == np.max(state_actions)].index)
    else:
        # Explore: pick any action uniformly at random
        action = np.random.choice(self.actions)
    return action
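Note that when the random number is below the greedy threshold, the best action is chosen from the Q-table. state_actions first pulls the Q-values of all actions for the given observation. action then takes the indices (that is, the actions) whose Q-value equals the maximum and picks one of them at random, so ties are broken randomly.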
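The idea of exploring early and exploiting late needs a threshold that changes over time, whereas the code in this post keeps it fixed. One possible schedule is sketched below; the function name greedy_probability, the linear shape, and the start and end values are all assumptions. It follows this post's convention that the threshold is the probability of acting greedily, so the value rises during training.

```python
def greedy_probability(step, total_steps, start=0.1, end=0.9):
    """Linearly raise the probability of acting greedily from start to end."""
    # Clamp progress to [0, 1] so steps beyond total_steps stay at `end`.
    frac = min(max(step / total_steps, 0.0), 1.0)
    return start + frac * (end - start)

for step in (0, 50, 100):
    print(step, greedy_probability(step, total_steps=100))
```

The returned value could be assigned to self.epsilon before each call to select_action.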
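1.2.3 Updating the Q-table
From the Bellman equation we get the update rule

$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$

Here $r + \gamma \max_{a'} Q(s', a')$ is q_target: the target Q-value depends on the current reward $r$ and on the largest Q-value over the actions available in the next state $s'$. (Think of it this way: the value of this step depends on the reward for this step plus an assessment of the future. Ideally, taking $a$ in $s$ leads to an $s'$ from which things go well.)
$Q(s, a)$ is q_predict: the predicted Q-value depends on the current action $a$ and state $s$, and is read directly from the table.
In the code, $\alpha$ is self.lr, $\gamma$ is self.gamma, and $s'$ is s_.
The update to the Q-table aims to bring q_predict as close to q_target as possible.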
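def learn(self, s, a, r, s_):
    self.check_state_exist(s_)
    q_predict = self.q_table.loc[s, a]
    if s_ != 'end':
        # Off-policy target: best Q-value available in the next state
        q_target = r + self.gamma * self.q_table.loc[s_, :].max()
    else:
        q_target = r  # terminal state: no future value
    # Move the prediction a fraction lr of the way toward the target
    self.q_table.loc[s, a] += self.lr * (q_target - q_predict)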
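A single update worked through with numbers may help. The values below are arbitrary and chosen only for illustration, and q_update is a hypothetical standalone version of the arithmetic inside learn.

```python
def q_update(q_sa, reward, max_q_next, lr=0.1, gamma=0.9, terminal=False):
    """Return the new Q(s, a) after one Q-learning update."""
    q_target = reward if terminal else reward + gamma * max_q_next
    return q_sa + lr * (q_target - q_sa)

# q_target = 1.0 + 0.9 * 2.0 = 2.8
# new Q    = 0.5 + 0.1 * (2.8 - 0.5) = 0.73
new_q = q_update(q_sa=0.5, reward=1.0, max_q_next=2.0)
print(round(new_q, 4))
```

With a learning rate of 0.1 the prediction covers only a tenth of the gap to the target per update, which is why the same state-action pair has to be visited many times before its value settles.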
2. Sarsa
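2.1 Sarsa basic principles
Sarsa, by contrast, is an on-policy algorithm: when updating a Q-value it no longer uses the maximum Q-value of the next state but the Q-value of the action actually chosen there.
In the SARSA algorithm, the Q-value is updated taking into account the action, A1, actually performed in the next state, S1. In Q-learning, the action with the highest Q-value in the next state, S1, is used to update the Q-table.
The pseudocode has the same structure as Q-learning, except that the next action is chosen before the update and its Q-value replaces the maximum.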
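The difference between the two targets can be shown on one row of Q-values. This is an illustrative sketch: td_targets is a hypothetical helper and the numbers are arbitrary.

```python
import numpy as np

def td_targets(reward, gamma, q_next_row, next_action):
    """Return (Q-learning target, Sarsa target) for the same transition."""
    # Q-learning bootstraps from the best action in the next state.
    q_learning_target = reward + gamma * np.max(q_next_row)
    # Sarsa bootstraps from the action the policy actually chose there.
    sarsa_target = reward + gamma * q_next_row[next_action]
    return q_learning_target, sarsa_target

# Next state has Q-values [1.0, 3.0]; the policy happened to pick action 0.
ql, sa = td_targets(reward=0.0, gamma=0.9,
                    q_next_row=np.array([1.0, 3.0]), next_action=0)
print(ql, sa)
```

The two targets coincide whenever the chosen next action is also the greedy one, and differ only on exploratory steps.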
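2.2 Code walkthrough
def learn(self, s, a, r, s_, a_):
    self.check_state_exist(s_)
    q_predict = self.q_table.loc[s, a]
    if s_ != 'end':
        # On-policy target: use the action a_ actually chosen in s_
        q_target = r + self.gamma * self.q_table.loc[s_, a_]
    else:
        q_target = r  # terminal state: no future value
    error = q_target - q_predict
    # Replacing trace: clear the row for s, then mark the visited (s, a)
    self.eligibility_trace.loc[s, :] *= 0
    self.eligibility_trace.loc[s, a] = 1
    # Update every (state, action) in proportion to its trace
    self.q_table += self.lr * error * self.eligibility_trace
    # Decay all traces after the update
    self.eligibility_trace *= self.gamma * self.lambda_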
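The learn method above is Sarsa(λ) with an eligibility trace (the class docstring in the full listing calls it "Sarsa Lambda"). Read off the code, one step can be written as the equations below. The symbols are my own notation: $\delta$ is error, $E$ is eligibility_trace, $\alpha$ is lr, and $\lambda$ is lambda_. For a terminal $s'$ the code uses $\delta = r - Q(s, a)$.

```latex
\begin{align}
\delta &= r + \gamma\, Q(s', a') - Q(s, a) \\
E(s, \cdot) &\leftarrow 0, \qquad E(s, a) \leftarrow 1 \\
Q(\tilde{s}, \tilde{a}) &\leftarrow Q(\tilde{s}, \tilde{a}) + \alpha\, \delta\, E(\tilde{s}, \tilde{a})
  \quad \text{for all } (\tilde{s}, \tilde{a}) \\
E(\tilde{s}, \tilde{a}) &\leftarrow \gamma \lambda\, E(\tilde{s}, \tilde{a})
  \quad \text{for all } (\tilde{s}, \tilde{a})
\end{align}
```

With $\lambda = 0$ every trace is wiped out after each step, so only the current $(s, a)$ is updated and the method reduces to the one-step Sarsa described in 2.1. Larger $\lambda$ lets a single error also adjust recently visited state-action pairs.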
3. Complete code
import numpy as np
import pandas as pd
class QLearning:
    def __init__(self, actions, lr, reward_decay, e_greedy):
        self.actions = actions      # list of action labels
        self.lr = lr                # learning rate
        self.gamma = reward_decay   # discount factor
        self.epsilon = e_greedy     # probability of acting greedily
        self.q_table = pd.DataFrame(
            columns=self.actions,
            dtype=np.float64  # np.float was removed in NumPy 1.24
        )

    def select_action(self, observation):
        self.check_state_exist(observation)
        if np.random.uniform() < self.epsilon:
            # Exploit: pick among the actions with the highest Q-value
            state_actions = self.q_table.loc[observation, :]
            action = np.random.choice(state_actions[state_actions == np.max(state_actions)].index)
        else:
            # Explore: pick any action uniformly at random
            action = np.random.choice(self.actions)
        return action

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'end':
            # Off-policy target: best Q-value available in the next state
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        else:
            q_target = r  # terminal state: no future value
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)

    def check_state_exist(self, state):
        # DataFrame.append was removed in pandas 2.0, so enlarge via .loc instead
        if state not in self.q_table.index:
            self.q_table.loc[state] = [0.0] * len(self.actions)
class Sarsa:
    """
    Sarsa Lambda
    """
    def __init__(self, actions, learning_rate, reward_decay, e_greedy, sarsa_lambda):
        self.actions = actions
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy     # probability of acting greedily
        self.q_table = pd.DataFrame(
            columns=self.actions,
            dtype=np.float64  # np.float was removed in NumPy 1.24
        )
        self.lambda_ = sarsa_lambda
        # Same shape as the Q-table; tracks recently visited (state, action) pairs
        self.eligibility_trace = self.q_table.copy()

    def select_action(self, observation):
        self.check_state_exist(observation)
        if np.random.uniform() < self.epsilon:
            # Exploit: pick among the actions with the highest Q-value
            state_actions = self.q_table.loc[observation, :]
            action = np.random.choice(state_actions[state_actions == np.max(state_actions)].index)
        else:
            # Explore: pick any action uniformly at random
            action = np.random.choice(self.actions)
        return action

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'end':
            # On-policy target: use the action a_ actually chosen in s_
            q_target = r + self.gamma * self.q_table.loc[s_, a_]
        else:
            q_target = r  # terminal state: no future value
        error = q_target - q_predict
        # Replacing trace: clear the row for s, then mark the visited (s, a)
        self.eligibility_trace.loc[s, :] *= 0
        self.eligibility_trace.loc[s, a] = 1
        # Update every (state, action) in proportion to its trace
        self.q_table += self.lr * error * self.eligibility_trace
        # Decay all traces after the update
        self.eligibility_trace *= self.gamma * self.lambda_

    def check_state_exist(self, state):
        # DataFrame.append was removed in pandas 2.0, so enlarge via .loc instead
        if state not in self.q_table.index:
            zeros = [0.0] * len(self.actions)
            self.q_table.loc[state] = zeros
            self.eligibility_trace.loc[state] = zeros