Sarsa算法
Sarsa算法是基于Q learning算法的,不同的是,Q learning在更新s1状态的Q表时,计算Q(s1,a2)现实时,会选择s2状态下的最优值,即最有可能会获得奖励的行为,但当他实际到s2状态时并不一定会选择最优行为。而Sarsa是行动派,即计算现实时选择了什么行为,那么他到s2状态时就会选择此行为,因此计算Q(s1,a2)现实的公式没有γ * maxQ(s2),而是γ * Q(s2,action)。
可以看出,每次更新状态时,即进入下一个状态,会把前面选的动作a’也一并更新,即s<-s’,a<-a’.
从上面的公式可以看出,Sarsa和Q learning主要的区别是更新s1的Q表时,s2选择的行为到底是真的还是假设。因此主要改动在learn更新Q表过程。
Q learning算法代码点击获取
def learn(self, s, a, r, s_, a_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, a_] # next state is not terminal
else:
q_target = r # next state is terminal
self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # update
对比Q learning算法代码,计算现实值时:
#Sarsa
q_target = r + self.gamma * self.q_table.loc[s_, a_]
#Q learning
q_target = r + self.gamma * self.q_table.loc[s_, :].max()
loc[s_, :].max()与loc[s_, a_] 一目了然。
在更新Q表的主循环中:
def S_update():
for episode in range(100):
# initial observation
observation = env.reset()
# RL choose action based on observation
action = RL.choose_action(str(observation))
while True:
# fresh env
env.render()
# RL take action and get next observation and reward
observation_, reward, done = env.step(action)
action_ = RL.choose_action(str(observation_))
# RL learn from this transition
RL.learn(str(observation), action, reward, str(observation_),action_)
# swap observation
observation = observation_
action = action_
# break while loop when end of this episode
if done:
break
不同于Q learning,Sarsa不需要每次寻找开始时选择行为,只有第一次先选择,只有每次行为都是之前计算现实值选好了的(即更新Q表前,看会在s2状态时选择什么行为,用于计算s1的现实值)。
补充:为什么说Q learning比Sarsa胆大?
如上图,假如此时在状态s1,对于Q learning而言,现在要更新s1状态的Q值了,假如选择向下的行为,即下一次会到达s2状态,对于s2状态而言,Qmax(s2)会是x,即在计算s1状态选择向下走行为的现实值时,永远会使用x,那么会导致这个现实值很大,即Q(s1,向下)值很大,那么在下一次训练中,当处于s1状态时,仍然会更多可能进入s2状态,而Q learning并不会将这个最优行为记录,而是在下一次寻找时重新选择,即进入s2状态后重新选择,那么就可能选择向右进入y,此时就非常危险了!!!
对于Sarsa而言,在计算s1状态选择向下走行为的现实值时,会根据选择的行为来计算,而不是一股脑按照最优,那么,只有算法选择向右进入y后,那么就会导致Q(s1,向下)较低,在下一次训练中,当处于s1状态时,就不一定会再进入s2状态,避免了进入y的危险性。
总结一下就是,在计算Q(s1,向下)时,Q learning永远觉得x这儿更近,更可能得到奖励,但实际到s2后可能进入y而“game over”,而这不会影响Q(s1,向下)值。Sarsa在训练几次后觉得s2状态旁边的y太危险,会导致Q(s1,向下)较低,因此之后就可能会绕开s2选择其他安全路径(该路径可能不是最优路径)
Q learning选择最优路径,但危险性高,Sarsa选择最安全路径,但不是最优。
Sarsa(lambda)算法
Sarsa算法是单步更新,即在更新Q表时,得到奖励后只会更新其最近的状态的Q值,而无法保存前面的步骤的Q值,只能多次循环逐渐找到路径。如下,我们从s1状态向下到s2,再向左得到奖励,在更新Q表时,只有Q(s2,向左)有记录,而Q(s1,向下)仍然为0,只能在下一次尝试时才可能更新Q(s1,向下),显然浪费了之前尝试的结果,那有什么方法在得到奖励后把从起点到终点的路径都更新到Q表中呢?
如下,Sarsa(0)代表单步更新,Sarsa(n)代表n次的回合更新,那么中间的Sarsa(λ)则是不确定步数的更新。
下面是单步更新和回合更新的解释,单步更新只会记录最近状态,回合更新则会全部都记录。
λ:脚步衰减值
λ是属于【0,1】的数字,代表衰减度,即离奖励点越远的状态,衰减的越厉害,我们关注点重要程度越小。
这是算法的伪代码,看不懂没关系,看代码更容易理解!!!
下面第一种情况是单步更新,即每次到达一个点后记录一下,如果没有奖励立马就失效了
第二种情况是没有限制的回合更新,每次到达这个点记录一次,然后按照λ逐渐衰减,再次到达该点继续在衰减的基础上再记录一次(记录可以理解为+1的操作,代表访问了一次),长时间未到达会慢慢衰减为0
第三种是有限制的回合更新,即加上限(例如上限为1),其余和第二种情况一样
看不懂可以先看下面的代码讲解!!!
完整代码获取
在Sarsa算法上面修改,主要改动在更新Q表的learn函数里:
改动learn之前,我们需要一个表来记录我们走过的状态信息,我们将其命名为eligibility_trace ,则在初始化时可以直接复制Q表,因为这两个表结构是一样的:
self.eligibility_trace = self.q_table.copy()
那么多了一个位置E表,那么在每次检查state是否已经存在时,也应该对E表操作:
def check_state_exist(self, state):
if state not in self.q_table.index:
# append new state to q table
to_be_append = pd.Series(
[0] * len(self.actions),
index=self.q_table.columns,
name=state,
)
self.q_table = self.q_table.append(to_be_append)
# 同样往E表中添加没有的坐标信息
self.eligibility_trace = self.eligibility_trace.append(to_be_append)
接下来就是修改learn过程了!
def learn(self, s, a, r, s_, a_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, a_] # next state is not terminal
else:
q_target = r # next state is terminal
error = q_target - q_predict
# increase trace amount for visited state-action pair
# Method 1:
# self.eligibility_trace.loc[s, a] += 1
# Method 2:
self.eligibility_trace.loc[s, :] *= 0
self.eligibility_trace.loc[s, a] = 1
# Q update
self.q_table += self.lr * error * self.eligibility_trace
# decay eligibility trace after update
self.eligibility_trace *= self.gamma*self.lambda_
前面计算预估值和现实值都没变化,改变的是更新Q表的过程。
注意:前面讲的回合更新并不是说找到奖励后再更新Q表,而是每次移动后都会更新,只不过每次更新的是整个Q表,而非单个Q值
上面代码的Methon1和Methon2就是有限制(上限为1)和无限制,更新Q表时
self.q_table += self.lr * error * self.eligibility_trace
利用E表信息实现全部路径的更新
更新完了后要对E表的值进行衰减:
self.eligibility_trace *= self.gamma*self.lambda_
即乘λ(λ<1)
下面是得到奖励前的Q表(第一个)和E表(第二个):
这是得到奖励对Q表进行更新后:
可以看出,整个路径的Q表值都得到了更新,并且,离奖励越远,衰减的越厉害!
附:可能有人会疑问,为什么最后一次前Q表都没有值?
因为第一次寻找时,前面的预估值和现实值都为0,即差距error为0,所以更新Q表时:
self.q_table += self.lr * error * self.eligibility_trace
一直为0,当得到奖励后,error不为0,根据E表就可以更新全部路径的Q值!