强化学习笔记之【TD3算法】
前言:
本文为强化学习笔记第三篇,第一篇讲的是Q-learning和DQN,第二篇讲的是DDPG
TD3就是比DDPG多了两个网络用来防止过估计,然后引入了延迟更新机制,就没了,还挺简单的
本文初编辑于2024.10.6
CSDN主页:https://blog.csdn.net/rvdgdsva
博客园主页:https://www.cnblogs.com/hassle
博客园本文链接:
首先,我们需要明确,Q-learning算法发展成DQN算法,DQN算法发展成为DDPG算法,而DDPG算法发展成TD3(Twin Delayed DDPG)算法
这里有两个问题
一、DQN算法中存在过估计
DQN的目标是优化Q值函数。DQN采用离线学习的方式,通过计算Q值的目标来更新其网络参数。DQN的目标公式可以表示为:
y
=
r
+
γ
×
m
a
x
(
Q
(
s
′
,
a
′
;
θ
′
)
)
y = r + γ \times max(Q(s', a'; θ'))
y=r+γ×max(Q(s′,a′;θ′))
在这里,r是当前动作的价值,**max(Q(s’, a’; θ’))**是后续动作中价值最大的动作的价值。
而往往在实际运行中没有办法达成最大价值这种理想情况
r
是当前时刻的奖励。γ
是折扣因子,用于平衡未来奖励的重要性。s'
是下一个状态,a'
是下一个动作。Q(s', a'; θ')
是目标网络的 Q 值。θ'
是目标网络的参数。
于是TD3算法用了一个讨巧但是有效的方式,搞两个网络,分别运行,取相对小价值的作为输出
y
=
r
+
γ
×
m
i
n
(
Q
1
(
s
′
,
π
(
s
′
)
)
,
Q
2
(
s
′
,
π
(
s
′
)
)
)
y = r + γ \times min(Q₁(s', π(s')), Q₂(s', π(s')))
y=r+γ×min(Q1(s′,π(s′)),Q2(s′,π(s′)))
r
是奖励。γ
是折扣因子。Q₁(s', π(s'))
和Q₂(s', π(s'))
是两个目标 Q 网络的 Q 值。π(s')
是目标策略网络生成的下一个动作。θ₁'
和θ₂'
是两个目标 Q 网络的参数,π
是目标策略网络的参数。
TD3 的核心改进在于 使用两个Q网络取最小值 来计算目标Q值,以减少过高估计问题(overestimation bias)。
二、DDPG算法存在局部最优
TD3中的Delay机制,主要体现在两个方面:延迟更新策略网络和延迟更新目标Q网络。这种延迟机制是对经典 DDPG(Deep Deterministic Policy Gradient)算法的一项重要改进
在 DDPG 中,策略网络(Actor)和Q网络(Critic)是交替更新的,这意味着策略网络在每次迭代时都能快速学习新的动作。然而,频繁更新策略网络会使它容易陷入局部最优,难以找到全局最优策略,尤其是在Q值估计不准确的情况下。
2.1 策略网络的优化原理
在强化学习中,策略网络(Actor)决定智能体在每个状态下应该采取的动作,目的是最大化未来的累积奖励。策略网络通过不断调整,学会选择能带来更高回报的动作。
当策略网络更新频繁时,它会在每一轮训练中快速调整自己的权重,期望基于当前的 Q 值(即 Critic 网络评估的动作价值)尽可能找到最优动作。
2.2 局部最优
局部最优是指在当前的策略空间中,智能体找到了一种动作选择,这种选择看起来已经是最好的(最大化了当前的 Q 值),但从更大的全局视角来看,这其实不是最优解。也就是说,策略可能在某个小范围内找到了一个“局部最佳解”,但离真正的全局最优解还有差距。
2.3 频繁更新为什么导致局部最优
当策略网络更新频繁时,可能会过于迅速地朝着当前 Q 网络评估出的“最佳方向”移动,然而,Q 网络本身的估值在早期阶段可能还不够准确或稳定。
- Q值的不稳定性:在强化学习的过程中,Q值估计在训练早期或数据不足时,常常会有误差。如果策略网络过于依赖这些不稳定的 Q 值进行快速更新,它会基于这些错误的估计来选择看似“最优”的动作,这可能让策略陷入一个局部最优,而没有足够时间探索更好的全局解。
- 策略更新速度过快:频繁更新策略网络,会使其迅速调整到一个“看似最优”的策略上。但由于更新速度太快,智能体可能还没有足够的时间探索整个策略空间,因此很容易错过更优的动作选择。
- 缺乏探索:策略网络在频繁更新过程中,可能对当前评估为“好的”动作过度偏好,而忽略了其他动作的探索,这样可能导致陷入局部最优,而未能发现更优的全局策略。
2.4 TD3中的改进
为了避免这种现象,TD3引入了延迟更新策略网络的机制。也就是说,在 Q 网络(Critic)经过多次更新后,策略网络(Actor)才会更新。通过这种延迟更新,策略网络能够基于更准确、更稳定的 Q 值进行更新,从而减少策略过快收敛到局部最优的风险。
三、DDPG算法和TD3算法代码对比
下面对比 DDPG 和 TD3 的代码或伪代码,特别是延迟策略更新和双Q网络的改进部分
另外,TD3在target_actor的输出后面加了个噪声
3.1 DDPG算法
# DDPG
for each iteration:
# 采样经验数据 (state, action, reward, next_state) from replay buffer
batch = replay_buffer.sample()
# Critic 更新 (Q网络)
next_action = target_actor(next_state)
target_q_value = reward + gamma * target_critic(next_state, next_action)
critic_loss = mse(critic(state, action), target_q_value)
critic_optimizer.zero_grad()
critic_loss.backward()
critic_optimizer.step()
# Actor 更新 (策略网络)
actor_loss = -critic(state, actor(state)).mean()
actor_optimizer.zero_grad()
actor_loss.backward()
actor_optimizer.step()
# 软更新目标网络参数
for param, target_param in zip(critic.parameters(), target_critic.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
for param, target_param in zip(actor.parameters(), target_actor.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
3.2 TD3算法
# TD3
for each iteration:
# 采样经验数据 (state, action, reward, next_state) from replay buffer
batch = replay_buffer.sample()
# Critic 更新 (双 Q 网络)
next_action = target_actor(next_state) + clip(noise, -noise_clip, noise_clip) # 引入噪声
target_q_value1 = target_critic1(next_state, next_action)
target_q_value2 = target_critic2(next_state, next_action)
target_q_value = reward + gamma * min(target_q_value1, target_q_value2) # 取最小值
critic1_loss = mse(critic1(state, action), target_q_value)
critic_optimizer1.zero_grad()
critic1_loss.backward()
critic_optimizer1.step()
critic2_loss = mse(critic2(state, action), target_q_value)
critic_optimizer2.zero_grad()
critic2_loss.backward()
critic_optimizer2.step()
# 延迟更新 Actor 网络
if iteration % policy_delay == 0: # 延迟更新策略网络
actor_loss = -critic1(state, actor(state)).mean()
actor_optimizer.zero_grad()
actor_loss.backward()
actor_optimizer.step()
# 软更新目标网络参数
for param, target_param in zip(critic1.parameters(), target_critic1.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
for param, target_param in zip(critic2.parameters(), target_critic2.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
for param, target_param in zip(actor.parameters(), target_actor.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
四、总结:DDPG 与 TD3 的关键区别
4.1 双 Q 网络 (Critic)
在 DDPG 中只有一个 Q目标网络,而 TD3 中使用两个 Q 目标网络,主要的改进是通过计算两个 Q 网络的最小值来缓解 Q 值的过估计问题。
4.2 延迟更新策略网络 (Actor)
TD3 的策略网络并不在每次 Q 网络更新后都立即更新,而是隔几次才更新一次。policy_delay
决定了每几次更新 Critic 后更新 Actor 的频率。这样做的目的是为了让 Q 网络先稳定下来,防止 Actor 网络基于不准确的 Q 值进行优化。
4.3 目标策略平滑 (Target Policy Smoothing)
TD3 在生成目标动作时加入噪声,并且通过裁剪噪声来避免过大的扰动。这样可以增加策略的探索性,减少由于策略确定性导致的高估问题。
五、TD3和DDPG网络对比
5.1 DDPG 中的四个网络
- Actor 网络(策略网络):
- 作用:决定给定状态 ss 时,应该采取的动作 a=π(s)a=π(s),目标是找到最大化未来回报的策略。
- 更新:基于 Critic 网络提供的 Q 值更新,以最大化 Critic 估计的 Q 值。
- Target Actor 网络(目标策略网络):
- 作用:为 Critic 网络提供更新目标,目的是让目标 Q 值的更新更为稳定。
- 更新:使用软更新,缓慢向 Actor 网络靠近。
- Critic 网络(Q 网络):
- 作用:估计当前状态 ss 和动作 aa 的 Q 值,即 Q(s,a)Q(s,a),为 Actor 提供优化目标。
- 更新:通过最小化与目标 Q 值的均方误差进行更新。
- Target Critic 网络(目标 Q 网络):
- 作用:生成 Q 值更新的目标,使得 Q 值更新更为稳定,减少振荡。
- 更新:使用软更新,缓慢向 Critic 网络靠近。
DDPG 中的四个网络总结:
- Actor 网络
- Target Actor 网络
- Critic 网络
- Target Critic 网络
5.2 TD3 中的六个网络
TD3 相比 DDPG 增加了两个网络,使得总共有六个网络。多出的网络用于改进 Q 值估计的准确性。
- Actor 网络(策略网络):
- 与 DDPG 相同,决定状态 ss 时采取的动作 a=π(s)a=π(s)。
- Target Actor 网络(目标策略网络):
- 与 DDPG 相同,作为 Actor 网络的目标,更新更为平滑和稳定。
- Critic 1 网络(第一个 Q 网络):
- 估计给定状态 ss 和动作 aa 的 Q 值 Q1(s,a)Q1(s,a)。
- Critic 2 网络(第二个 Q 网络):
- 另一个 Q 网络,估计给定状态 ss 和动作 aa 的 Q 值 Q2(s,a)Q2(s,a)。目标是在 Q 值估计中避免过度高估。
- Target Critic 1 网络(目标 Q 网络 1):
- 作为 Critic 1 网络的目标,类似于 DDPG 中的 Target Critic 网络。
- Target Critic 2 网络(目标 Q 网络 2):
- 作为 Critic 2 网络的目标,用于为 Critic 2 网络生成更稳定的目标值。
TD3 中的六个网络总结:
- Actor 网络
- Target Actor 网络
- Critic 1 网络
- Critic 2 网络
- Target Critic 1 网络
- Target Critic 2 网络
5.3 对比总结
- DDPG 里有四个网络:Actor、Target Actor、Critic、Target Critic。
- TD3 里有六个网络:多了一个 Critic 2 和 Target Critic 2,用于减小 Q 值估计的偏差。