强化学习——Sarsa Lambda找宝藏

目录

在Sarsa的基础上改进的sarsa lambda算法

Sarsa存在的问题

改进方法2:Sarsa Lambda

参考


开始每天被老师抓着写周报,以后想摸鱼都摸不了,心态baozha……

在Sarsa的基础上改进的sarsa lambda算法

算法流程和数学推导就不写了,弄清楚lambda的含义:

  • 如果 lambda = 0, Sarsa-lambda 就是 Sarsa, 只更新获取到 reward 前经历的最后一步.
  • 如果 lambda = 1, Sarsa-lambda 更新的是 获取到 reward 前所有经历的步.

lambda表示想要选择的步数,是一个衰减值

和之前的奖励衰减值一样,lambda是脚步衰减值

Sarsa存在的问题

经过上一次的训练:https://xduwq.blog.csdn.net/article/details/105826501

能发现Sarsa存在这一个很严重的问题:由于Sarsa是一种保守的算法,代价是经常陷入局部最优,也可以称之为过拟合现象,卡在某个步骤畏惧不前,原本如果实验成功,将在1000~3000步数之内完成训练,可是如果陷入了局部最优的话,训练几万步都不会成功,可以认为是实验失败!

改进办法1:降低negative reward的值,可以明显改善实验效果!

这里吧negative reward降低为-0.5!

改进方法2:Sarsa Lambda

主要RL_brain.py进行了改动,其余代码和Sarsa一样!

import numpy as np
import pandas as pd

class RL(object):
    def __init__(self, action_space, learning_rate=0.01,reward_decay=0.9,e_greedy=0.9):
        self.actions = action_space  # a list
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy

        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):

        if state not in self.q_table.index:
            # append new state to q table
            self.q_table = self.q_table.append(
                pd.Series(
                    [0] * len(self.actions),
                    index=self.q_table.columns,
                    name=state,
                )
            )

    def choose_action(self, observation):
        self.check_state_exist(observation)
        # action selection
        if np.random.rand() < self.epsilon:
            # choose best action
            state_action = self.q_table.loc[observation, :]
            # some actions may have the same value, randomly choose on in these actions
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # choose random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, *args):
        pass

# 离线学习QLearning
class QLearningTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]

        if s_ != 'terminal':    # next state is not terminal
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()  # Q-Learning中是选择最大值
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update


# 在线学习SarsaLambdaTable
class SarsaLambdaTable(RL):
    # 初始化
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
        super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

        self.lambda_ = trace_decay
        self.eligibility_trace = self.q_table.copy()

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # 在qtable中添加新的state
            to_be_append = pd.Series(
                [0] * len(self.actions),
                index = self.q_table.columns,
                name = state
            )
            self.q_table = self.q_table.append(to_be_append)
            self.eligibility_trace = self.eligibility_trace.append(to_be_append)


    # 学习更新参数
    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)  # 检查状态是否存在
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':  # next state is not terminal
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # 直接选择下一个行动的值
        else:
            q_target = r  # next state is terminal
        # self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # 更新值
        error = q_target - q_predict

        # 不可或缺性
        self.eligibility_trace.ix[s,:] *= 0
        self.eligibility_trace.ix[s,a] = 1

        # 更新Q table
        self.q_table += self.lr * error * self.eligibility_trace

        self.eligibility_trace *= self.gamma*self.lambda_

参考

主要复现莫烦Python:https://zhuanlan.zhihu.com/p/24860793

文末再膜一下莫烦dalao

  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

沉迷单车的追风少年

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值