强化学习之离散动作采样 vs 连续动作采样

强化学习之离散动作采样 vs 连续动作采样

强化学习(Reinforcement Learning, RL)是一种训练智能体(agent)在环境中学习决策策略的方法。不同的任务可能涉及不同类型的动作空间:

  • 离散动作空间:动作是有限集合中的一个类别,例如 Atari 游戏(左、右、跳等)。
  • 连续动作空间:动作是一个连续值,如机器人控制中的关节角度或汽车转向角。

本文介绍强化学习中常见的离散和连续动作采样方法,并分析如何计算动作采样的对数概率(log probability, logp)(有了logp才能求导优化)。


1. 离散动作采样

在离散动作空间中,我们通常使用 分类分布(Categorical Distribution) 进行采样。假设策略网络输出一个动作概率分布 π ( a ∣ s ) \pi(a|s) π(as),其中 a a a 是离散动作, s s s 是当前状态。

1.1 采样方法

策略网络通常使用 softmax 层输出每个动作的概率:
P ( a i ∣ s ) = e z i ∑ j e z j P(a_i | s) = \frac{e^{z_i}}{\sum_{j} e^{z_j}} P(ais)=jezjezi
其中 z i z_i zi 是策略网络对动作 a i a_i ai 计算的 logits。

采样动作:

  1. 计算 softmax 概率分布。
  2. 从该概率分布中采样一个动作 a ∼ π ( a ∣ s ) a \sim \pi(a|s) aπ(as)

在 PyTorch 中实现如下:

import torch
import torch.nn.functional as F

def sample_discrete_action(logits):
    probs = F.softmax(logits, dim=-1)  # 计算 softmax 概率
    dist = torch.distributions.Categorical(probs)
    action = dist.sample()
    logp = dist.log_prob(action)
    return action, logp

1.2 计算 logp

对于离散动作,logp 的计算为:
log ⁡ P ( a ∣ s ) = log ⁡ e z a ∑ j e z j = z a − log ⁡ ∑ j e z j \log P(a | s) = \log \frac{e^{z_a}}{\sum_{j} e^{z_j}} = z_a - \log \sum_{j} e^{z_j} logP(as)=logjezjeza=zalogjezj

在代码中,我们可以直接使用 torch.distributions.Categorical.log_prob(action) 计算 logp。


2. 连续动作采样

对于连续动作空间,常见的方法是使用 高斯分布(Gaussian Distribution) 进行采样。假设策略网络输出动作的均值 μ ( s ) \mu(s) μ(s) 和标准差 σ ( s ) \sigma(s) σ(s)

2.1 采样方法

假设策略网络参数化一个正态分布:
a ∼ N ( μ , σ 2 ) a \sim \mathcal{N}(\mu, \sigma^2) aN(μ,σ2)
在 PyTorch 中的实现如下:

def sample_continuous_action(mu, log_std):
	# mu, log_std 保证最后一维为动作空间
    std = torch.exp(log_std)  # 计算标准差
    dist = torch.distributions.Normal(mu, std)
    action = dist.rsample()  # 使用 reparameterization trick 进行采样
    logp = dist.log_prob(action).sum(dim=-1)  # 计算 logp
    return action, logp

这里使用 rsample() 进行采样,它支持重参数化技巧(reparameterization trick),重参数化技巧如下式所示:
a = μ + σ ⋅ ϵ , ϵ ∼ N ( 0 , 1 ) a = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0,1) a=μ+σϵ,ϵN(0,1)

2.2 重参数化技巧的作用

在连续动作空间中,每个具体的动作 a a a 作为一个点的概率密度在数学上趋近于零,因此不能直接计算其概率。为了解决这个问题,我们利用概率密度函数(PDF)来估计动作的发生概率。重参数化技巧(Reparameterization Trick)通过引入一个可微的随机变量(如标准正态分布的噪声 ϵ \epsilon ϵ),使得采样过程变得可导,从而可以通过梯度下降优化策略。这在策略梯度方法和深度强化学习(如 PPO 和 SAC)中尤为重要。

2.3 计算 logp

a a a服从均值为 μ \mu μ,标准差为 σ \sigma σ的高斯分布,取log之后的概率计算如下:
log ⁡ P ( a ∣ s ) = − 1 2 ( ( a − μ ) 2 σ 2 + 2 log ⁡ σ + log ⁡ 2 π ) \log P(a | s) = -\frac{1}{2} \left( \frac{(a - \mu)^2}{\sigma^2} + 2 \log \sigma + \log 2\pi \right) logP(as)=21(σ2(aμ)2+2logσ+log2π)

在代码中,dist.log_prob(action) 直接计算这个值,并且如果是多维连续动作空间,需要对所有维度求和 (.sum(dim=-1))。

2.4 连续动作空间的范围限制

在许多实际应用中,连续动作的取值需要限制在一定范围内,例如:

  • 机器人关节的旋转角度必须在 [-π, π] 之间。
  • 车辆的加速度或转向角应限制在合理范围。

为了实现这一点,通常在策略网络的输出上应用 tanh 函数,使得动作被限制在 [-1, 1] 的范围内:

def sample_bounded_continuous_action(mu, log_std, action_low, action_high):
    std = torch.exp(log_std)
    dist = torch.distributions.Normal(mu, std)
    raw_action = dist.rsample()
    bounded_action = torch.tanh(raw_action)  # 将动作限制在 (-1, 1)
    scaled_action = action_low + (action_high - action_low) * (bounded_action + 1) / 2  # 重新缩放
    logp = dist.log_prob(raw_action).sum(dim=-1) - torch.log(1 - bounded_action.pow(2) + 1e-6).sum(dim=-1)  # 计算 logp
    return scaled_action, logp

这里的 tanh 限制了动作范围,同时需要对 logp 进行修正,加入 log ⁡ ( 1 − tanh ⁡ 2 ( a ) ) \log(1 - \tanh^2(a)) log(1tanh2(a)) 项,以保持梯度的正确性。

log ⁡ P ( a ′ ∣ s ) = log ⁡ P ( a ∣ s ) − ∑ i log ⁡ ( 1 − tanh ⁡ 2 ( a i ) ) \log P(a' | s) = \log P(a | s) - \sum_{i} \log \left( 1 - \tanh^2(a_i) \right) logP(as)=logP(as)ilog(1tanh2(ai))

推导过程:

由于 a ′ a' a a a a 通过 tanh 变换得到,我们可以使用变换分布的概率密度函数公式
P ( a ′ ) = P ( a ) ∣ d a d a ′ ∣ P(a') = P(a) \left| \frac{da}{da'} \right| P(a)=P(a)dada
其中,tanh 变换的导数为:
d a ′ d a = 1 − tanh ⁡ 2 ( a ) = 1 − a ′ 2 \frac{da'}{da} = 1 - \tanh^2(a) = 1 - a'^2 dada=1tanh2(a)=1a2
因此,Jacobian 修正项(绝对值的倒数)为:
∣ d a d a ′ ∣ = 1 1 − a ′ 2 \left| \frac{da}{da'} \right| = \frac{1}{1 - a'^2} dada=1a21
代入后:
P ( a ′ ) = P ( a ) ⋅ 1 1 − a ′ 2 P(a') = P(a) \cdot \frac{1}{1 - a'^2} P(a)=P(a)1a21

变换分布的概率密度函数公式直观解释:

由于概率密度表示单位长度上的概率,当变量 X X X变换为 Y Y Y时,单位长度可能会被拉伸或压缩,因此密度需要乘以导数的绝对值进行调整,以确保总概率保持不变。


3. 离散 vs 连续:对比分析

特性离散动作采样连续动作采样
典型分布分类分布(Categorical)高斯分布(Normal)
采样方式直接从 softmax 采样采样自 N ( μ , σ 2 ) \mathcal{N}(\mu, \sigma^2) N(μ,σ2)
logp 计算使用 softmax 计算高斯分布公式计算
适用场景游戏策略、象棋、Atari机器人控制、自动驾驶
训练难度相对容易优化可能需要熵惩罚防止塌陷(SAC)

总结来说,

  • 离散动作空间适用于选择性决策任务,如游戏和离散控制问题。
  • 连续动作空间适用于机器人和物理模拟环境。
  • 计算 logp 时,离散使用 softmax + Categorical 分布,而连续使用正态分布的 log 公式。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值