[TD3]算法简介、代码分析以及教你改代码

非常优秀的论作,建议去看。这里写一些我所收货的知识以及知识的归纳。参考链接:1.《浅谈TD3:从算法原理到代码实现》2. 《【深度强化学习】TD3算法:DDPG的进化》3. 《强化学习之TD3算法实现》4. 《论文笔记之TD3算法》—牛!5. 论文《Fujimoto, Scott, Herke van Hoof, and Dave Meger. “Addressing Function Approximation Error in Actor-Critic Methods.》的原文,下载原文一
摘要由CSDN通过智能技术生成

非常优秀的论作,建议去看。这里写一些我所收货的知识以及知识的归纳。如果我写的让你满意,请点个赞👍
参考链接:
1.《浅谈TD3:从算法原理到代码实现》
2. 《【深度强化学习】TD3算法:DDPG的进化》
3. 《强化学习之TD3算法实现》
4. 《论文笔记之TD3算法》—牛!
5. 论文《Fujimoto, Scott, Herke van Hoof, and Dave Meger. “Addressing Function Approximation Error in Actor-Critic Methods.》的原文,下载原文

一、简介与目标

TD3:Twin Delayed Deep Deterministic Policy Gradient ,双延迟深度确定性策略梯度。TD3算法是一个对DDPG优化的版本,即TD3也是一种基于AC架构的面向连续动作空间的DRL算法。
TD3这篇文章主要解决两个问题,一个是overestimation bias,另一个是high variance。

1.1 overestiamtion

从DQN–>DDPG–>TD3都是基于价值学习的算法,(价值函数趋于真实函数,两者之间始终存在误差,该误差会导致Q值高估次优策略的原因—个人理解)。

算法DQN-overestiamtion的原因
  1. 过估计是指估计的值函数比真实的值函数大。
  2. DQN存在过估计的原因:

DQN是一种off-policy的方法,每次学习时,不是使用下一次交互的真实动作,而是使用当前认为价值最大的动作来更新目标值函数,所以会出现对Q值的过高估计。基于函数逼近方法的值函数更新公式可以看出。

  1. 解决方法
    Hasselt提出了Double Q Learning方法,将此方法应用到DQN中,就是Double DQN,即DDQN。
    所谓的Double Q Learning是将动作的选择动作的评估分别用不同的值函数来实现。DDQN借鉴了Double Q-learning的思想,将选取action和估计value分别在predict network 和 target network网络上计算,有效优化了DQN的Q-Value过高估计问题
算法DDPG也存在overestiamtion
  1. actor-critic算法在这篇文章中被发现也存在overestiamtion
  2. 解决方法:TD3的思路.
    使用 两套网络(Twin) 表示不同的Q值,通过选取最小的那个作为我们更新的目标(Target Q Value)抑制持续地过高估计
  3. 两者区别:DDPG算法涉及了4个网络,所以TD3需要用到6个网络

1.2 high variance(高方差)

解决思路:

  1. Actor更新的Delay。
    也就是说相对于Critic可以更新多次后,Actor再进行更新。如果Q能稳定下来再学习policy,应该就会减少一些错误的更新;所以,我们可以把Critic的更新频率,调的比Actor要高一点。让critic更加确定,actor再行动.(对应的是学习率的不同)

  2. Target Policy Smoothing Regularization.目标策略网络的平滑正则化
    在TD3中,目标Q-的更新方式:
    y ← r + γ min ⁡ i − 1 , 2 Q θ i ′ ( s ’ , a ~ ) y \leftarrow r + \gamma \min_{i-1,2} Q_{\theta '_i }(s’,\tilde{a}) yr+γi1,2minQθi(s,a~)
    平滑的方式:
    a ~ ← π ϕ ‘ ( s ’ ) + ϵ , ϵ ∼ c l i p ( N ( 0 , σ ~ ) , − c , c ) \tilde{a} \leftarrow \pi_{\phi ‘}(s’)+\epsilon, \epsilon \sim clip(\mathcal{N}(0, \tilde{\sigma}), -c,c) a~πϕ(s)+ϵ,ϵclip(N(0,σ~),c,c)
    对policy网络引入随机噪声,以期达到对policy波动的稳定性。

二、算法原理

算法流程:
在这里插入图片描述
参考大佬
在这里插入图片描述

三、代码分析及修改

原文的数据是基于s.shape=[B,F]。而我的数据是state.shape=[B,N,F].

3.1 网络基础结构

Actor网络
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.f1 = nn.Linear(state_dim, 256)
        self.f2 = nn.Linear(256, 128)
        self.f3 = nn.Linear(128, action_dim)
        self.max_a
  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
TD3算法是一种强化学习算法,主要用于解决连续控制问题,它在深度强化学习中具有很高的实用性。下面是一个简单的MATLAB实现: ``` % TD3算法实现 % 请注意,这是一个简单的代码示例,可能需要根据实际情况进行修 % 环境初始化 env = rlPredefinedEnv('Pendulum-Continuous'); obsInfo = getObservationInfo(env); actInfo = getActionInfo(env); % 神经网络参数初始化 actorNetwork = [ imageInputLayer([obsInfo.Dimension(1) 1 1],'Normalization','none','Name','observation') fullyConnectedLayer(256,'Name','fc1') reluLayer('Name','relu1') fullyConnectedLayer(256,'Name','fc2') reluLayer('Name','relu2') fullyConnectedLayer(actInfo.Dimension(1),'Name','actorOutput') tanhLayer('Name','actorTanh')]; criticNetwork = [ imageInputLayer([obsInfo.Dimension(1) 1 1],'Normalization','none','Name','observation') fullyConnectedLayer(256,'Name','fc1') reluLayer('Name','relu1') fullyConnectedLayer(256,'Name','fc2') reluLayer('Name','relu2') fullyConnectedLayer(1,'Name','criticOutput')]; actorOpts = rlRepresentationOptions('LearnRate',1e-3,'GradientThreshold',1); criticOpts = rlRepresentationOptions('LearnRate',1e-3,'GradientThreshold',1); actor = rlDeterministicActorRepresentation(actorNetwork,obsInfo,actInfo,'Observation',{'observation'},'Action',{'actorTanh'},actorOpts); critic = rlValueRepresentation(criticNetwork,obsInfo,criticOpts,'Observation',{'observation'}); % TD3算法参数初始化 agentOpts = rlTD3AgentOptions; agentOpts.SampleTime = 0.01; agentOpts.DiscountFactor = 0.99; agentOpts.ExperienceBufferLength = 1e6; agentOpts.TargetSmoothFactor = 5e-3; agentOpts.NoiseOptions.Variance = 0.2; agentOpts.NoiseOptions.VarianceDecayRate = 1e-5; agentOpts.NoiseOptions.StepSize = 0.01; % 创建TD3代理 agent = rlTD3Agent(actor,critic,agentOpts); % 训练代理 trainOpts = rlTrainingOptions; trainOpts.MaxEpisodes = 500; trainOpts.MaxStepsPerEpisode = ceil(env.Ts/env.StepSize); trainOpts.ScoreAveragingWindowLength = 10; trainOpts.StopTrainingCriteria = 'AverageReward'; trainOpts.StopTrainingValue = -100; trainOpts.SaveAgentCriteria = 'EpisodeReward'; trainOpts.SaveAgentValue = -100; trainOpts.Plots = 'training-progress'; trainOpts.Verbose = false; % 训练代理 trainingStats = train(agent,env,trainOpts); % 测试代理 simOptions.ResetFcn = @(in) setVariable(env,in,env.ResetFcn()); simOptions.StopTime = 20; experience = sim(env,agent,simOptions); ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值