强化学习(一)--Sarsa与Q-learning算法
最近实验室有一个项目要用到强化学习,在这开个新坑来记录下强化学习的学习过程。
第一节就先来最简单的基于表格型的RL算法,包括经典的Sarsa和Q-learning算法。
由于时间原因,关于算法的理论知识不再详细介绍,重点是研究怎么编程实现,代码是参考的飞浆PaddlePaddle公开课的代码,下来又自己手撸了一遍。飞浆PaddlePaddle公开课是我认为最适合入门强化学习的公开课,科老师讲解的真的非常清晰,公开课地址。
1. SARSA算法
sarsa算法是最基础的on-policy算法,它采用的是TD单步更新的方式,每一个step都会更新Q表格,Q表格的更新公式为:这也是代码最核心的部分,它就是将Q值不断逼近目标值,也就是未来总收益。
Sarsa的名字就来源于它更新Q表格时所用到的五个参数:S,A,R,S’,A’,它的算法伪代码为:
第一次看伪代码可能会有些懵,公开课里很贴心的给出了流程图:
![在这里插入图片描述](https://img-blog.csdnimg.cn/20210315185651139.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM3MzMzMDQ4,size_16,color_FFFFFF,t_70
根据流程图很容易就能编程实现Sarsa算法。
2. Q-learning算法
Q-learning算法则是off-policy算法,与sarsa算法一样都是采用查表的方式,不同的地方在于它的A’默认为最优策略选择的动作,而sarsa的A则是下一个状态要实际执行的动作。
因此,Q-learning算法的Q表更新公式有些不同,可以看到Target_Q使用的是下个状态下最大的Q值来更新Q表格:
Q-learning算法的伪代码,可以看到下一时刻的动作并不一定去执行:
它的流程图,与sarsa的对比就能看出不同:
3. 代码实现
以sarsa算法为例,来讲解一下怎么进行代码实现,这里使用的环境为gym中的CliffWalking,它有四个动作 :0 up, 1 right, 2 down, 3 left。小乌龟每走一步reward = -1,掉入黑色方框内reward=-100,小乌龟被拖到起点重新开始。
3.1主函数
主函数主要承担导入环境,定义智能体,训练及测试。
- 导入的环境是gym中现有的环境,可以直接使用,具体的使用方法可以看看gym库的使用方法。
- 定义的智能体是SarsaAgent,使我们定义的一个类,由五个参数需要设定。
- 训练函数run_episode() 和测试函数test_episode() 也是接下来要实现的。
- 共进行500个episode的训练,每个episode都输出进行多少步和总的reward,每20个episode,我们输出可视化一下。
- 训练结束后我们测试下结果。
# 主函数
def main():
# 导入环境
env = gym.make("CliffWalking-v0")
env = CliffWalkingWapper(env)
# env = gym.make("FrozenLake-v0",is_slippery = False)
# env = FrozenLakeWapper(env)
agent = SarsaAgent(
obs_n = env.observation_space.n,
act_n = env.action_space.n,
learning_rate = 0.1,
gamma = 0.9,
e_greed =