强化学习离线强化学习CQL模型详解

一、离线强化学习 vs 在线强化学习

在这里插入图片描述

1. 在线强化学习 (Online RL)
  • 核心机制:智能体通过与环境实时交互收集数据,并基于这些交互数据动态更新策略
  • 优点:能持续探索环境,适应动态变化。
  • 缺点:需要高昂的交互成本,且可能因探索不足导致策略次优。
2. 离线强化学习 (Offline RL)
  • 核心机制:智能体仅使用预先收集的静态数据集 D \mathcal{D} D)进行训练,不与环境交互
  • 优点:适用于高风险或高成本场景(如医疗、自动驾驶)。
  • 难点
  • 分布偏移 (Distribution Shift):策略生成的行动可能偏离数据集中的分布,导致Q值高估。
  • 外推误差 (Extrapolation Error):对未见过的状态-动作对 ( s , a ) (s,a) (s,a)的Q值估计不可靠,且Q leaning的max操作容易放大Q值,overestimation问题。
  • 保守性困境:需在利用数据与避免冒险之间权衡。

二、CQL (Conservative Q-Learning) 原理

1. 核心思想

CQL通过约束Q函数的估计范围,防止对未见动作的Q值高估,从而学习一个保守的Q函数

2. 目标函数设计

CQL的损失函数包含两部分:

  • 标准Q学习损失:最小化贝尔曼误差。
  • 保守正则项:约束Q值不过分偏离数据集支持的动作。

完整损失函数
L ( θ ) = E ( s , a , r , s ′ ) ∼ D [ ( Q θ ( s , a ) − ( r + γ max ⁡ a ′ Q θ ˉ ( s ′ , a ′ ) ) ) 2 ] + α ⋅ ( E s ∼ D [ log ⁡ ∑ a exp ⁡ ( Q θ ( s , a ) ) ] − E ( s , a ) ∼ D [ Q θ ( s , a ) ] ) \mathcal{L}(\theta) = \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}} \left[ \left( Q_\theta(s,a) - \left( r + \gamma \max_{a'} Q_{\bar{\theta}}(s',a') \right) \right)^2 \right] + \alpha \cdot \left( \mathbb{E}_{s \sim \mathcal{D}} \left[ \log \sum_a \exp(Q_\theta(s,a)) \right] - \mathbb{E}_{(s,a) \sim \mathcal{D}} \left[ Q_\theta(s,a) \right] \right) L(θ)=E(s,a,r,s)D[(Qθ(s,a)(r+γamaxQθˉ(s,a)))2]+α(EsD[logaexp(Qθ(s,a))]E(s,a)D[Qθ(s,a)])
其中 α \alpha α是权衡系数。


三、CQL保守下界的推导

1. 动机

在离线RL中,直接最大化Q值会导致对数据分布外的动作过估计。CQL通过最小化Q值的上界,构造一个保守的Q函数。

2. 关键步骤
  • 约束条件:要求Q函数在数据分布外的动作上给出较低估值

  • 数学形式:通过引入正则项,使得Q函数满足:
    E s ∼ D [ E a ∼ π ( a ∣ s ) [ Q ( s , a ) ] − E a ∼ D ( a ∣ s ) [ Q ( s , a ) ] ] ≤ 0 \mathbb{E}_{s \sim \mathcal{D}} \left[ \mathbb{E}_{a \sim \pi(a|s)} [Q(s,a)] - \mathbb{E}_{a \sim \mathcal{D}(a|s)} [Q(s,a)] \right] \leq 0 EsD[Eaπ(as)[Q(s,a)]EaD(as)[Q(s,a)]]0
    即策略 π \pi π的期望Q值不超过数据分布 D \mathcal{D} D的期望Q值。

  • 策略 π \pi π的期望Q值:
    假设静态数据集 D \mathcal{D} D采样的状态是分布稳定的,从当前策略 π \pi π计算动作概率分布
    E s ∼ D [ E a ∼ π ( a ∣ s ) [ Q ( s , a ) ] ] \mathbb{E}_{s \sim \mathcal{D}}\left[\mathbb{E}_{a \sim \pi(a|s)} [Q(s,a)]\right] EsD[Eaπ(as)[Q(s,a)]]

  • 数据分布 D \mathcal{D} D的期望Q值
    状态动作对 ( s , a ) (s,a) (s,a)服从数据集 D \mathcal{D} D的分布,即动作 a a a和采样时的行为策略有关
    E ( s , a ) ∼ D [ Q ( s , a ) ] \mathbb{E}_{(s,a) \sim \mathcal{D}} [Q(s,a)] E(s,a)D[Q(s,a)]

3. 推导保守下界
3.1. 原始约束优化问题

离线RL的关键挑战是:策略改进时,Q函数可能对数据分布外(OOD)的动作产生高估。CQL通过约束Q函数,使得策略π下的Q期望不超过数据分布D下的Q期望,构建如下优化问题:

min ⁡ Q L ( Q ) ⏟ 标准Q学习损失 s.t. E s ∼ D [ E a ∼ π [ Q ( s , a ) ] − E a ∼ D [ Q ( s , a ) ] ] ≤ ϵ \min_Q \underbrace{\mathcal{L}(Q)}_{\text{标准Q学习损失}} \quad \text{s.t.} \quad \mathbb{E}_{s \sim \mathcal{D}} \left[ \mathbb{E}_{a \sim \pi}[Q(s,a)] - \mathbb{E}_{a \sim \mathcal{D}}[Q(s,a)] \right] \leq \epsilon Qmin标准Q学习损失 L(Q)s.t.EsD[Eaπ[Q(s,a)]EaD[Q(s,a)]]ϵ

  • 目标:最小化Q函数的标准TD误差(如贝尔曼误差)。
  • 约束:策略π的Q期望与数据分布D的Q期望之差不超过阈值ε,防止OOD动作的高估。

3.2. 拉格朗日松弛法

将约束转化为正则项,引入拉格朗日乘子μ,得到松弛后的目标函数:

min ⁡ Q max ⁡ μ ≥ 0 L ( Q ) + μ ( E s ∼ D [ E a ∼ π [ Q ( s , a ) ] − E a ∼ D [ Q ( s , a ) ] ] − ϵ ) \min_Q \max_{\mu \geq 0} \mathcal{L}(Q) + \mu \left( \mathbb{E}_{s \sim \mathcal{D}} \left[ \mathbb{E}_{a \sim \pi}[Q(s,a)] - \mathbb{E}_{a \sim \mathcal{D}}[Q(s,a)] \right] - \epsilon \right) Qminμ0maxL(Q)+μ(EsD[Eaπ[Q(s,a)]EaD[Q(s,a)]]ϵ)

  • 物理意义:通过调整μ,平衡Q函数的学习与保守性约束。

3.3. 策略期望的变分近似

直接计算策略π的期望 E a ∼ π [ Q ( s , a ) ] \mathbb{E}_{a \sim \pi}[Q(s,a)] Eaπ[Q(s,a)] 需要显式策略π,但在离线设置中π是未知的。CQL通过以下技巧近似:

  • 假设策略π为均匀分布:即对所有动作赋予相同概率,此时策略期望退化为平均Q值。
  • 引入熵正则化:假设π是最大化熵的策略,即 π ( a ∣ s ) ∝ exp ⁡ ( Q ( s , a ) ) \pi(a|s) \propto \exp(Q(s,a)) π(as)exp(Q(s,a)),此时策略期望可解析表达为:
    E a ∼ π [ Q ( s , a ) ] = log ⁡ ∑ a exp ⁡ ( Q ( s , a ) ) − 常数 \mathbb{E}_{a \sim \pi}[Q(s,a)] = \log \sum_a \exp(Q(s,a)) - \text{常数} Eaπ[Q(s,a)]=logaexp(Q(s,a))常数

3.4. 凸对偶性与正则项简化

通过凸对偶性(Convex Duality),将策略期望的约束转化为显式正则项:

  • 关键步骤:将拉格朗日乘子μ的优化问题转化为对偶形式。

  • 推导过程

    (1). 假设策略π为均匀分布,最大化 E a ∼ π [ Q ( s , a ) ] \mathbb{E}_{a \sim \pi}[Q(s,a)] Eaπ[Q(s,a)] 等价于最大化 log ⁡ ∑ a exp ⁡ ( Q ( s , a ) ) \log \sum_a \exp(Q(s,a)) logaexp(Q(s,a))(即LogSumExp函数)。

    (2). 通过凸共轭函数(Convex Conjugate),将约束项转化为正则项:
    μ ( log ⁡ ∑ a exp ⁡ ( Q ( s , a ) ) − E ( s , a ) ∼ D [ Q ( s , a ) ] ) \mu \left( \log \sum_a \exp(Q(s,a)) - \mathbb{E}_{(s,a) \sim \mathcal{D}}[Q(s,a)] \right) μ(logaexp(Q(s,a))E(s,a)D[Q(s,a)])

    (3). 忽略常数项后,正则项简化为:
    E s ∼ D [ log ⁡ ∑ a exp ⁡ ( Q ( s , a ) ) ] − E ( s , a ) ∼ D [ Q ( s , a ) ] \mathbb{E}_{s \sim \mathcal{D}} \left[ \log \sum_a \exp(Q(s,a)) \right] - \mathbb{E}_{(s,a) \sim \mathcal{D}}[Q(s,a)] EsD[logaexp(Q(s,a))]E(s,a)D[Q(s,a)]

3.5. LogSumExp的物理意义
  • 第一项 log ⁡ ∑ exp ⁡ ( Q ( s , a ) ) \log \sum \exp(Q(s,a)) logexp(Q(s,a))
    表示所有动作Q值的软最大值(Soft Maximum),即对Q值的上界估计。

  • 作用:惩罚Q函数在任意动作(包括OOD动作)上的高估值。

  • 第二项 E D [ Q ( s , a ) ] \mathbb{E}_\mathcal{D}[Q(s,a)] ED[Q(s,a)]
    表示数据分布内动作的Q值均值。

  • 作用:保留数据分布内动作的Q值真实性。

  • 差值
    通过拉大“所有动作Q值的上界”与“数据分布内Q值”的差距,显式约束Q函数在OOD动作上的高估。

4. 最终损失函数

将正则项与贝尔曼误差结合,得到CQL的损失函数:
L ( θ ) = 贝尔曼误差 + α ⋅ ( 正则项 ) \mathcal{L}(\theta) = \text{贝尔曼误差} + \alpha \cdot \left( \text{正则项} \right) L(θ)=贝尔曼误差+α(正则项)
通过调节 α \alpha α,控制保守程度。


四、CQL的意义

  • 理论保证:CQL的Q函数是真实Q值的下界,避免了高估风险。
  • 实践效果:在离线数据集上显著优于传统Q-learning,尤其在数据覆盖不足时。
  • 灵活性:可与其他RL算法(如SAC、DQN)结合使用。

公式总结

  • 贝尔曼误差
    E ( s , a , r , s ′ ) [ ( Q ( s , a ) − ( r + γ max ⁡ a ′ Q ( s ′ , a ′ ) ) ) 2 ] \mathbb{E}_{(s,a,r,s')} \left[ \left( Q(s,a) - (r + \gamma \max_{a'} Q(s',a')) \right)^2 \right] E(s,a,r,s)[(Q(s,a)(r+γamaxQ(s,a)))2]
  • CQL正则项
    α ⋅ ( E s [ log ⁡ ∑ a exp ⁡ ( Q ( s , a ) ) ] − E ( s , a ) [ Q ( s , a ) ] ) \alpha \cdot \left( \mathbb{E}_s \left[ \log \sum_a \exp(Q(s,a)) \right] - \mathbb{E}_{(s,a)} [Q(s,a)] \right) α(Es[logaexp(Q(s,a))]E(s,a)[Q(s,a)])

通过这种设计,CQL在离线RL中实现了保守且鲁棒的Q值估计。

五、代码demo

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Q网络定义
class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)
    
    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

class CQL:
    def __init__(self, state_dim, action_dim, device='cuda', alpha=1.0, gamma=0.99, lr=3e-4):
        self.q_net = QNetwork(state_dim, action_dim).to(device)
        self.target_q_net = QNetwork(state_dim, action_dim).to(device)
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.gamma = gamma
        self.alpha = alpha  # 保守项的系数
        self.device = device
    
    def update(self, batch):
        states, actions, rewards, next_states, dones = batch
        
        # 计算贝尔曼误差
        with torch.no_grad():
            next_q = self.target_q_net(next_states)
            max_next_q = next_q.max(1)[0]
            target_q = rewards + (1 - dones) * self.gamma * max_next_q
        
        current_q = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        bellman_loss = nn.MSELoss()(current_q, target_q)

        # 计算CQL保守项
        # 计算所有动作的Q值,并求logsumexp
        q_values = self.q_net(states) 
        q_logsumexp = torch.logsumexp(q_values, dim=1).mean()  # 对每个状态的所有动作取logsumexp

        # 数据集中动作的Q值均值
        data_q = self.q_net(states).gather(1, actions.unsqueeze(1)).mean()
        
        conservative_loss = self.alpha * (q_logsumexp - data_q)
        
        # 总损失
        total_loss = bellman_loss + conservative_loss
        
        # 梯度更新
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        
        # 更新目标网络
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        
        return total_loss.item()

# 使用示例
state_dim = 4
action_dim = 2
device = 'cuda' if torch.cuda.is_available() else 'cpu'

agent = CQL(state_dim, action_dim, device=device)

# 假设有一个离线数据集batch
batch = (
    torch.randn(256, state_dim).to(device),      # states
    torch.randint(0, action_dim, (256,)).to(device), # actions
    torch.randn(256).to(device),                 # rewards
    torch.randn(256, state_dim).to(device),      # next_states
    torch.zeros(256).to(device)                  # dones
)

loss = agent.update(batch)
print(f"Total Loss: {loss}")

六、代码细节和扩展

上述代码假设动作空间是离散的,Q函数的输入是state,输出是在每个action上的Q值。
如果是连续动作空间,则需要做一些改动。首先连续动作空间需要使用DDPG,Actor- Critic结构,Actor是Policy网络,输入state,输出在continuous action space的分布。Critic是Q网络,且输入是 ( s t a t e ∥ a c t i o n ) (state \Vert action) (stateaction),输出是当前状态动作对的Q值。此时计算CQL的正则化损失需要在动作空间进行采样操作,因为无法枚举action计算logsumexp了。

# 计算CQL
# 每个状态采样10个动作(均匀采样或策略采样)
sampled_actions = torch.rand(batch_size * 10, action_dim) * 2 - 1  # 示例:均匀采样

# 计算Q值
repeated_states = states.unsqueeze(1).repeat(1, 10, 1).view(-1, state_dim)
q_values = self.q_net(repeated_states, sampled_actions)  # 形状:(batch_size*10, 1)
q_logsumexp = torch.logsumexp(q_values, dim=0).mean()     # 跨所有动作样本计算

离散 vs. 连续动作空间的处理差异

  • 离散动作空间

    • 直接枚举所有动作:若动作空间离散(如 action_dim=2),Q网络输出所有动作的Q值,无需额外采样。

    • 无需状态重复:每个状态只需计算一次Q值矩阵,直接对所有动作取 logsumexp。

  • 连续动作空间

    • 无法枚举动作:需通过采样近似估计 logsumexp。

    • 每个状态需采样多个动作:对每个状态生成多个动作样本(如10个),计算它们的Q值并取 logsumexp。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

贝塔西塔

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值