强化学习之TD3算法实现

学习研究强化学习有个必不可少的工具就是——环境Env。在学习初期的时候,我们大多会在简单Gym环境比如Cartpolo、MountainCar、Pendulum、Puckworld、迷宫寻宝这几种环境中实现我们的算法。那么随着算法的复杂提高,我们需要进一步在复杂环境中实现算法。以TD3算法中的7种环境为新一轮起点,进行研究。如果你看过最近几年的论文,你会发现大多数算法都是在MUJOCO环境下实现的,TD3也不例外。
因此复现TD3算法需要

  1. 装有带MUJOCO环境的完整版Gym包(网上有许多教程,都可以来安装,整个过程会稍有点曲折)。
  2. GPU(带CUDA)上训练。
  3. 因为服务器上很难显示Env的动态变化,因此可在虚拟机(CPU即可)上实现测试。
  4. 学习曲线的绘制是在Pycharm上使用pandas包。

环境

  1. 复现只对TD3、DDPG、AHE三种算法。
  2. 只在MUJOCO的7种环境中实现。分别是:
    HalfCheetah-v2:在这里插入图片描述Hopper-v2:
    在这里插入图片描述
    Walker2d-v2:
    在这里插入图片描述
    Ant-v2:
    在这里插入图片描述Reacher-v2:
    在这里插入图片描述
    InvertedPendukum-v2:
    在这里插入图片描述
    InvertedDoublePendulum-v2:
    在这里插入图片描述

TD3算法原理

TD3算法应该重视的原因:

  1. DDPG存在过估计,而TD3引入CDQ技术来减缓这个问题。
  2. DDPG训练不稳定,这是由于AC算法本质造成的,而TD3引入的DP和TPS技术来减缓高方差问题。

在这里插入图片描述
其实TD3算法本质就是:
一个以DDPG为框架,然后把3种减缓估计误差的技巧添加上去,厉害的是,这三种技巧是相辅相成的,也就是说三种技巧的叠加比任意2种、任意1种或是DDPG的效果都要好。具体见我对TD3的详细论文解读
哪三种技巧呢?

  1. CDQ,截断(clip)的双Q学习。本质就是利用Double Q-learning的强解耦性,但是实际上做不到完全独立,为了避免误差进一步扩大,使用取 M i n Min Min操作来减缓过估计。
  2. DP,策略延迟。让Actor网络更新比Critic慢一点,确保Critic比较准确了,即TD error比较小了,再更新Actor,这样可以防止方差太大以及降低了偏差,防止累计误差扩大。这里用到了DQN中的Target网络技术,用于抑制高方差,同时也作为衡量TD error到了何时到了足够小。DP还有一个重要的作用在于缩短了训练时间,因为Actor几乎有一半的时间没有在训练。
  3. TPS,目标策略平滑。本质在于“相近的动作其 Q Q Q值应该是相同的”。故通过对目标策略加入高斯噪声完成对TD目标值的平滑,减小方差。

另外从伪代码可以看出,TD3除了3个技巧以外,其余基本和DDPG类似。

TD3算法实战

TD3作者是给出源代码的,其源代码斌不复杂,按照DDPG的代码以及伪代码就可以看得懂,这里捡几点稍微要注意的地方:

  1. 网络的设置:TD3网络、DDPG网络、our_DDPG网络(为了简化,各自的Target网络均没画,其实都只是深复制一份即可):①TD3网络:在这里插入图片描述
    在这里插入图片描述
    ②③DDPG网络和AHE:Actor部分是一样的;Critic部分属于值函数近似网络第二种中的两种变体:输入是(s,a),输出是值函数。在这里插入图片描述在这里插入图片描述

  2. CDQ实现:

    	# 截断的双Q学习 CDQ
    	target_Q1, target_Q2 = self.critic_target(next_state, next_action)
    	target_Q = torch.min(target_Q1, target_Q2)
    	target_Q = reward + not_done * self.discount * target_Q
    
  3. DP实现:

		if self.total_it % self.policy_freq == 0:
			#目标策略延迟 DP
			# Compute actor losse
			actor_loss = -self.critic.Q1(state, self.actor(state)).mean()		
			# Optimize the actor 
			self.actor_optimizer.zero_grad()
			actor_loss.backward()
			self.actor_optimizer.step()
			# Critic目标策略更新
			for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
				target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
			# Actor目标策略更新
			for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
				target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
  1. TPS实现:
		next_action = (
		# 目标策略平滑 TPS
		self.actor_target(next_state) + noise
		).clamp(-self.max_action, self.max_action)
  1. 这里需要注意的是关于Critic训练的时候,和伪代码中有点不同:
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

这个其实很好理解,就 Q 1 、 Q 2 Q_1、Q_2 Q1Q2同时训练,其实CDQ本来是按照Doubel Q-learning一样,在单个step内,要么只更新 Q 1 Q_1 Q1,要么只更新 Q 2 Q_2 Q2,但为了节约计算资源,就共用1个TD目标值,但经过证明,这样也是可以收敛的,即最终 Q 1 = Q 2 → Q ∗ Q_1=Q_2\to Q^* Q1=Q2Q

TD3原论文中对7种环境进行了7种算法的比较,为了简化这一过程,复现只针对HalfCheetah-v2环境和TD3,即不做算法之间的对比。

复现结果

复现结果主要是视频以及学习曲线:

训练(Training)
训练100万个steps,每过5000个steps,测试一下当前model的学习情况,测试10个Episodes,10次均是由不同的随机种子设定的。我们需要记录三个值:最大累计奖励、最小累计奖励、平均累计奖励(期望累计奖励,可以看成是 Q 真 值 Q_{真值} Q,相当于MC估计)。故整个训练过程一共200组数据,每组数据共3个元素。
通过pandas绘出的学习曲线如下:在这里插入图片描述
Note:

  1. 蓝色实现 Q 真 值 Q_{真值} Q
  2. 红色阴影部分为最大和最小值的差距,代表着方差(Var)。
  3. 从整体来看,Agent学到了最优策略,直到训练100W个学习步为止, Q Q Q值在1W分左右,后续可能还会提升,但幅度应该不大。
  4. Learning curve几乎和论文中一样,可能稳定性上略差一点。
  5. HalfCheetah-v2环境下每个Epsiode固定为1000个steps,即输出done。测试的目的是让Agent持续的前向奔跑,不后退也不摔倒。分数越高表现力越强。

测试(Testing)
测试主要是对训练结束后的model进行测试。测试进行了20个Episode,对每个Episode统计游戏分数,结果如下:
在这里插入图片描述
为了更直观的表现学习成果,接下来以gif形式展现学习前后的Agent:

训练前后视频结果点这里(电脑端若看不了,可用手机端打开)。

实验总结

总体来说TD3是一个不错的算法,特别是在连续动作空间中可以较快的以高表现力学习到最优策略,原文作者提供了开源的代码以及论文,这些资料我都放在我的另一篇TD3论文解读里的开头,有兴趣的可以打开来下载!

  • 10
    点赞
  • 52
    收藏
    觉得还不错? 一键收藏
  • 14
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值