chapter 21 策略梯度算法

本章主要讲解2部分内容:“策略梯度算法原理”;Tensorflow下“策略梯度算法”的实现;

一、策略梯度算法

前面几章讲的Q Learning ,SARSA,DQN都是计算一个Q[state,action],然后根据Q[state,:].argmax()来决定智能体在state下的action,本章讲的“策略梯度算法”则是直接input智能体的state,output智能体当下应该做出的action。
“策略梯度算法”的核心是一个“策略梯度网络”,其目标是通过训练样本,学习“策略梯度网络”的参数,从而获得“策略梯度网络”,当向其input智能体state后,其能给出智能体需做的action。

本章以Cartpole游戏为例,讲解“策略梯度算法”的训练过程:
1、Cartpole游戏简介
Cartpole游戏中有2个物体,一个是Cart,一个是pole,玩家可以通过移动Cart控制pole的角度,游戏的目的是尽可能保持pole直立,当pole倾斜超过一定角度后,游戏将结束。游戏中,保持pole直立得时间越长,获得的reward越多。以下为Cartpole游戏界面:

openAI 开源的Gym提供了一个Cartpole游戏环境:Cartpole-v0,该环境中定义了4个变量:

  • 状态:用一个4维向量表示,该4维向量分别为“平台坐标”和“杆坐标”;
  • 动作:每个state下,需要从2个动作中选择一个执行:[“向左移动”,“向右移动”]。
  • 奖励:当游戏进行时,每个时刻获得一个奖励r。
  • 游戏结束:如果杆倾斜超过15度,或平台移出边缘,则游戏结束。

2、利用Cartpole游戏训练“策略梯度网络”
“策略梯度网络” = {input:智能体状态,output:智能体向右移动的概率}

在实践中,一般使用“交叉熵损失:-sum(lnp(a|s)),来定义Net的损失函数,但是在本案例中,我们并不知道智能体在state下的正确action,只知道其获得多少奖励。
在“策略梯度网络”训练过程中,我们需要为每个action指定一个值A,当A>0时,表示action是正确的,A<0时,表示action是错误的,此时,“策略梯度网络”的损失函数定义为:-Aln(p(a|s)),我们的目标仍然是最小化该损失函数,那么,当A>0时,表示此时的action为正确的,要最大化概率p(a|s),反之,则要最小化概率p(a|s),据此,可以对“策略梯度网络”利用“梯度下降算法”进行训练。
实际中,定义A的方式有很多中,比如,当智能体在这局游戏中win时,我们可以定义训练样本中所有action_i的A_i =1,反之A_i = -1。
一种更常用的A定义方法为“期望奖励方法”:,其中r_t为每个时刻t的reward,gamma为折扣因子,根据该公式,可以计算出训练样本中,智能体在每个state下采取action_i的Ai。

在定义好“策略梯度网络”={input:(平台坐标,杆坐标),output:向右移动的概率,loss:minimize(-sum(Aln(p(a|s))))},我们可以根据训练样本,来进行“策略梯度网络”的学习。
在这里,每一个训练样本train_sample为“一局游戏”,表示为si = {X_i:state_i [平台坐标,杆坐标],label_i:[A_i,action_i]},train_sample=[s1,s2,…,sn],n为智能体在一局游戏中行动的步数。

训练“策略梯度网络”核心思想:
step1:初始化“策略梯度网络”的参数;
step2:利用“策略梯度网络”玩一局游戏,获得train_sample={state,action,A}。
step3:基于train_sample,以及损失函数minimize(-sum(Aln(p(a|s)))),利用“梯度下降方法”,更新“策略梯度网络”的参数,更新“策略梯度网络”。
重复step2,step3不断训练“策略梯度网络”,直到智能体在一局游戏中的reward_sum > threshold为止,停止训练。

二、策略梯度算法的Tensorflow实现

cartpole_pg.py : 利用游戏cartpole来训练一个“策略梯度网络”,使智能体能够通过该“策略梯度网络”自行完cartpole游戏。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Sarah ฅʕ•̫͡•ʔฅ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值