强化学习之离散动作采样 vs 连续动作采样
强化学习(Reinforcement Learning, RL)是一种训练智能体(agent)在环境中学习决策策略的方法。不同的任务可能涉及不同类型的动作空间:
- 离散动作空间:动作是有限集合中的一个类别,例如 Atari 游戏(左、右、跳等)。
- 连续动作空间:动作是一个连续值,如机器人控制中的关节角度或汽车转向角。
本文介绍强化学习中常见的离散和连续动作采样方法,并分析如何计算动作采样的对数概率(log probability, logp)(有了logp才能求导优化)。
1. 离散动作采样
在离散动作空间中,我们通常使用 分类分布(Categorical Distribution) 进行采样。假设策略网络输出一个动作概率分布 π ( a ∣ s ) \pi(a|s) π(a∣s),其中 a a a 是离散动作, s s s 是当前状态。
1.1 采样方法
策略网络通常使用 softmax 层输出每个动作的概率:
P
(
a
i
∣
s
)
=
e
z
i
∑
j
e
z
j
P(a_i | s) = \frac{e^{z_i}}{\sum_{j} e^{z_j}}
P(ai∣s)=∑jezjezi
其中
z
i
z_i
zi 是策略网络对动作
a
i
a_i
ai 计算的 logits。
采样动作:
- 计算 softmax 概率分布。
- 从该概率分布中采样一个动作 a ∼ π ( a ∣ s ) a \sim \pi(a|s) a∼π(a∣s)。
在 PyTorch 中实现如下:
import torch
import torch.nn.functional as F
def sample_discrete_action(logits):
probs = F.softmax(logits, dim=-1) # 计算 softmax 概率
dist = torch.distributions.Categorical(probs)
action = dist.sample()
logp = dist.log_prob(action)
return action, logp
1.2 计算 logp
对于离散动作,logp 的计算为:
log
P
(
a
∣
s
)
=
log
e
z
a
∑
j
e
z
j
=
z
a
−
log
∑
j
e
z
j
\log P(a | s) = \log \frac{e^{z_a}}{\sum_{j} e^{z_j}} = z_a - \log \sum_{j} e^{z_j}
logP(a∣s)=log∑jezjeza=za−logj∑ezj
在代码中,我们可以直接使用 torch.distributions.Categorical.log_prob(action)
计算 logp。
2. 连续动作采样
对于连续动作空间,常见的方法是使用 高斯分布(Gaussian Distribution) 进行采样。假设策略网络输出动作的均值 μ ( s ) \mu(s) μ(s) 和标准差 σ ( s ) \sigma(s) σ(s)。
2.1 采样方法
假设策略网络参数化一个正态分布:
a
∼
N
(
μ
,
σ
2
)
a \sim \mathcal{N}(\mu, \sigma^2)
a∼N(μ,σ2)
在 PyTorch 中的实现如下:
def sample_continuous_action(mu, log_std):
# mu, log_std 保证最后一维为动作空间
std = torch.exp(log_std) # 计算标准差
dist = torch.distributions.Normal(mu, std)
action = dist.rsample() # 使用 reparameterization trick 进行采样
logp = dist.log_prob(action).sum(dim=-1) # 计算 logp
return action, logp
这里使用 rsample()
进行采样,它支持重参数化技巧(reparameterization trick),重参数化技巧如下式所示:
a
=
μ
+
σ
⋅
ϵ
,
ϵ
∼
N
(
0
,
1
)
a = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0,1)
a=μ+σ⋅ϵ,ϵ∼N(0,1)
2.2 重参数化技巧的作用
在连续动作空间中,每个具体的动作 a a a 作为一个点的概率密度在数学上趋近于零,因此不能直接计算其概率。为了解决这个问题,我们利用概率密度函数(PDF)来估计动作的发生概率。重参数化技巧(Reparameterization Trick)通过引入一个可微的随机变量(如标准正态分布的噪声 ϵ \epsilon ϵ),使得采样过程变得可导,从而可以通过梯度下降优化策略。这在策略梯度方法和深度强化学习(如 PPO 和 SAC)中尤为重要。
2.3 计算 logp
a
a
a服从均值为
μ
\mu
μ,标准差为
σ
\sigma
σ的高斯分布,取log之后的概率计算如下:
log
P
(
a
∣
s
)
=
−
1
2
(
(
a
−
μ
)
2
σ
2
+
2
log
σ
+
log
2
π
)
\log P(a | s) = -\frac{1}{2} \left( \frac{(a - \mu)^2}{\sigma^2} + 2 \log \sigma + \log 2\pi \right)
logP(a∣s)=−21(σ2(a−μ)2+2logσ+log2π)
在代码中,dist.log_prob(action)
直接计算这个值,并且如果是多维连续动作空间,需要对所有维度求和 (.sum(dim=-1)
)。
2.4 连续动作空间的范围限制
在许多实际应用中,连续动作的取值需要限制在一定范围内,例如:
- 机器人关节的旋转角度必须在 [-π, π] 之间。
- 车辆的加速度或转向角应限制在合理范围。
为了实现这一点,通常在策略网络的输出上应用 tanh 函数,使得动作被限制在 [-1, 1] 的范围内:
def sample_bounded_continuous_action(mu, log_std, action_low, action_high):
std = torch.exp(log_std)
dist = torch.distributions.Normal(mu, std)
raw_action = dist.rsample()
bounded_action = torch.tanh(raw_action) # 将动作限制在 (-1, 1)
scaled_action = action_low + (action_high - action_low) * (bounded_action + 1) / 2 # 重新缩放
logp = dist.log_prob(raw_action).sum(dim=-1) - torch.log(1 - bounded_action.pow(2) + 1e-6).sum(dim=-1) # 计算 logp
return scaled_action, logp
这里的 tanh
限制了动作范围,同时需要对 logp 进行修正,加入
log
(
1
−
tanh
2
(
a
)
)
\log(1 - \tanh^2(a))
log(1−tanh2(a)) 项,以保持梯度的正确性。
log P ( a ′ ∣ s ) = log P ( a ∣ s ) − ∑ i log ( 1 − tanh 2 ( a i ) ) \log P(a' | s) = \log P(a | s) - \sum_{i} \log \left( 1 - \tanh^2(a_i) \right) logP(a′∣s)=logP(a∣s)−i∑log(1−tanh2(ai))
推导过程:
由于
a
′
a'
a′ 由
a
a
a 通过 tanh
变换得到,我们可以使用变换分布的概率密度函数公式:
P
(
a
′
)
=
P
(
a
)
∣
d
a
d
a
′
∣
P(a') = P(a) \left| \frac{da}{da'} \right|
P(a′)=P(a)∣∣∣∣da′da∣∣∣∣
其中,tanh
变换的导数为:
d
a
′
d
a
=
1
−
tanh
2
(
a
)
=
1
−
a
′
2
\frac{da'}{da} = 1 - \tanh^2(a) = 1 - a'^2
dada′=1−tanh2(a)=1−a′2
因此,Jacobian 修正项(绝对值的倒数)为:
∣
d
a
d
a
′
∣
=
1
1
−
a
′
2
\left| \frac{da}{da'} \right| = \frac{1}{1 - a'^2}
∣∣∣∣da′da∣∣∣∣=1−a′21
代入后:
P
(
a
′
)
=
P
(
a
)
⋅
1
1
−
a
′
2
P(a') = P(a) \cdot \frac{1}{1 - a'^2}
P(a′)=P(a)⋅1−a′21
变换分布的概率密度函数公式直观解释:
由于概率密度表示单位长度上的概率,当变量 X X X变换为 Y Y Y时,单位长度可能会被拉伸或压缩,因此密度需要乘以导数的绝对值进行调整,以确保总概率保持不变。
3. 离散 vs 连续:对比分析
特性 | 离散动作采样 | 连续动作采样 |
---|---|---|
典型分布 | 分类分布(Categorical) | 高斯分布(Normal) |
采样方式 | 直接从 softmax 采样 | 采样自 N ( μ , σ 2 ) \mathcal{N}(\mu, \sigma^2) N(μ,σ2) |
logp 计算 | 使用 softmax 计算 | 高斯分布公式计算 |
适用场景 | 游戏策略、象棋、Atari | 机器人控制、自动驾驶 |
训练难度 | 相对容易优化 | 可能需要熵惩罚防止塌陷(SAC) |
总结来说,
- 离散动作空间适用于选择性决策任务,如游戏和离散控制问题。
- 连续动作空间适用于机器人和物理模拟环境。
- 计算 logp 时,离散使用 softmax + Categorical 分布,而连续使用正态分布的 log 公式。