【论文阅读笔记-meta rl】MBML:解决任务推断中的伪相关性

NeurIPS 2020

Li, J., Vuong, Q., Liu, S., Liu, M., Ciosek, K., Christensen, H., & Su, H. (2020). Multi-task batch reinforcement learning with metric learning. Advances in neural information processing systems, 33, 6197-6210.

在多任务强化学习中,我们通常希望训练一个能够泛化到多个任务的策略。然而,当训练数据来自不同任务且它们的状态-动作分布差异较大时,策略可能会学习到错误的任务推断方式——即仅根据状态-动作对判断任务,而忽略奖励信号。这会导致在未见任务上表现不佳。本文提出了一种结合三元组损失过渡重新标记的方法,MBML(Multi-task Batch RL with Metric Learning),通过近似奖励函数对跨任务数据进行重标注,构造"难负样本",再用三元组损失迫使任务推断模块必须依赖奖励信息而非仅靠状态-动作模式,强制任务推断模块同时考虑状态、动作和奖励,从而提升泛化性能。此外,训练好的策略作为初始化可大幅提升后续训练的收敛速度。

一、直观例子:为什么只看状态-动作对会出问题?

假设我们要训练一个智能体在二维平面上导航到不同目标位置的任务。我们有两个训练任务:

  • 任务1:导航到目标位置 Goal 1
  • 任务2:导航到目标位置 Goal 2

我们收集了两个数据集:

  • 任务1的数据(红色方块)主要集中在 Goal 1 周围
  • 任务2的数据(蓝色方块)主要集中在 Goal 2 周围

由于两个目标位置相距较远,红色与蓝色方块在状态-动作空间中没有重叠。如果我们训练一个任务推断模块,它可能学会:

“红色方块 → 任务1,蓝色方块 → 任务2”

而完全忽略了奖励信号(例如距离目标越近奖励越高)。

测试时的问题:在未见任务中(例如真实目标是 Goal 1),智能体随机收集了一些过渡数据(绿色方块)。如果这些绿色方块在状态空间上与蓝色方块更接近,模型会错误地推断当前任务是任务2,导致智能体向错误的目标移动。

这个例子揭示了多任务批强化学习中的一个核心挑战:当训练任务的数据分布差异大时,模型容易学习到伪相关性,仅依赖状态-动作对进行任务推断,而忽略了奖励信号。

二、研究背景与问题定义

强化学习基础

强化学习(RL)中,智能体通过与环境交互学习一个策略,以最大化累积奖励。一个任务通常建模为一个马尔可夫决策过程(MDP):

  • 状态空间 S \mathcal{S} S
  • 动作空间 A \mathcal{A} A
  • 转移函数 T ( s ′ ∣ s , a ) T(s' \mid s, a) T(ss,a)
  • 奖励函数 R ( s , a , s ′ ) R(s, a, s') R(s,a,s)
  • 初始状态分布 T 0 T_0 T0

策略 π ( a ∣ s ) \pi(a \mid s) π(as) 是一个从状态到动作的映射。目标是最优化:
J ( π ) = E τ ∼ π [ ∑ t = 0 H − 1 R ( s t , a t , s t ′ ) ] J(\pi) = \mathbb{E}_{\tau \sim \pi} \left[ \sum_{t=0}^{H-1} R(s_t, a_t, s'_t) \right] J(π)=Eτπ[t=0H1R(st,at,st)]

批强化学习

批强化学习(Batch RL)指的是仅使用一个预先收集的离线数据集 B = { ( s t , a t , r t , s t ′ ) } \mathcal{B} = \{(s_t, a_t, r_t, s'_t)\} B={(st,at,rt,st)} 训练策略,而不允许与环境进行额外交互。典型的算法如 BCQ(Batch Constrained Q-Learning),它通过生成候选动作并添加微小扰动来进行受限探索。

多任务批强化学习问题

给定 K K K 个批数据集 { B i } i = 1 K \{\mathcal{B}_i\}_{i=1}^K {Bi}i=1K,每个数据集来自一个不同的任务 M i M_i Mi。我们的目标是训练一个多任务策略 π θ \pi_\theta πθ,使其在从同一任务分布 p ( M ) p(M) p(M) 中采样的未见任务上表现良好。关键挑战是:

  1. 测试时任务身份未知,策略必须从收集的过渡数据中推断任务;
  2. 不同任务的数据集可能在状态-动作分布上差异很大(即分布不重叠),导致任务推断模块可能忽略奖励信号。

在在线多任务RL中,策略可以通过持续采集数据逐步消除分布差异,实现"自纠错"。但在Batch RL中,数据是静态的,任务推断模块一旦学到错误依赖关系就无法修正。这正是本文聚焦的核心难题。

任务推断模块

我们使用一个任务推断模块 q ϕ q_\phi qϕ,它接收一个上下文集 c i \mathbf{c}_i ci(来自任务 i i i 的一组过渡数据),输出一个任务身份的后验分布 q ϕ ( z ∣ c i ) q_\phi(z \mid \mathbf{c}_i) qϕ(zci)。策略则同时接收状态 s s s 和推断出的任务身份 z z z π ( s , z ) \pi(s, z) π(s,z)

三、核心挑战:伪相关性与错误的任务依赖

在训练中,我们希望通过蒸馏多个单任务策略来得到一个多任务策略。具体来说,我们:

  1. 为每个训练任务训练一个 BCQ 策略(得到 Q i , G i , ξ i Q_i, G_i, \xi_i Qi,Gi,ξi);
  2. 将这些策略蒸馏为一个多任务策略( Q D , G D , ξ D Q_D, G_D, \xi_D QD,GD,ξD),其输入除了状态外还包括推断的任务身份 z z z

蒸馏损失函数(以值函数为例)为:
L Q = 1 K ∑ i = 1 K E ( s , a ) , c i ∼ B i [ ( Q i ( s , a ) − Q D ( s , a , z i ) ) 2 + β KL ( q ϕ ( c i ) ∥ N ( 0 , 1 ) ) ] , z i ∼ q ϕ ( c i ) \mathcal{L}_Q = \frac{1}{K} \sum_{i=1}^K \mathbb{E}_{(s,a),\mathbf{c}_i \sim \mathcal{B}_i} \left[ (Q_i(s,a) - Q_D(s,a,z_i))^2 + \beta \text{KL}(q_\phi(\mathbf{c}_i) \| \mathcal{N}(0,1)) \right], \quad z_i \sim q_\phi(\mathbf{c}_i) LQ=K1i=1KE(s,a),ciBi[(Qi(s,a)QD(s,a,zi))2+βKL(qϕ(ci)N(0,1))],ziqϕ(ci)

然而,当不同任务的数据集在状态-动作分布上不重叠时,模型可能学会:
P ( Z ∣ S , A ) 而非正确的 P ( Z ∣ S , A , R ) P(Z \mid S, A) \quad \text{而非正确的} \quad P(Z \mid S, A, R) P(ZS,A)而非正确的P(ZS,A,R)

即,模型仅依赖状态-动作对推断任务,而忽略了奖励。这会导致在测试时,如果收集的过渡数据与某个训练任务的状态-动作分布更接近,即使奖励模式不同,模型也会错误地推断任务身份。

四、方法:MBML —— 通过度量学习增强任务推断

为了解决上述问题,我们提出了 MBML(Multi-task Batch RL with Metric Learning),其核心是三元组损失过渡重新标记

1. 过渡重新标记:构建硬负例

我们为每个训练任务 i i i 学习一个奖励函数近似器 R ^ i \hat{R}_i R^i。给定一个来自任务 j j j 的上下文集 c j \mathbf{c}_j cj,我们用 R ^ i \hat{R}_i R^i 重新标记其中的奖励,得到:
c j → i = { ( s j , t , a j , t , R ^ i ( s j , t , a j , t ) , s j , t ′ ) } \mathbf{c}_{j \to i} = \{(s_{j,t}, a_{j,t}, \hat{R}_i(s_{j,t}, a_{j,t}), s'_{j,t})\} cji={(sj,t,aj,t,R^i(sj,t,aj,t),sj,t)}
这相当于将任务 j j j 的数据“伪装”成任务 i i i 的数据(状态-动作对相同,但奖励不同)。

2. 三元组损失:强制模型考虑奖励

对于每个任务 i i i,我们构建三元组:

  • 锚点(Anchor) c j → i \mathbf{c}_{j \to i} cji(重新标记后的数据)
  • 正例(Positive) c i \mathbf{c}_i ci(原始任务 i i i 的数据)
  • 负例(Negative) c j \mathbf{c}_j cj(原始任务 j j j 的数据)

三元组损失定义为:
L triplet i = 1 K − 1 ∑ j ≠ i [ d ( q ϕ ( c j → i ) , q ϕ ( c i ) ) ⏟ 锚点-正样本距离 − d ( q ϕ ( c j → i ) , q ϕ ( c j ) ) ⏟ 锚点-负样本距离 + m ] + \mathcal{L}_{\text{triplet}}^i = \frac{1}{K-1} \sum_{j\neq i} \Big[ \underbrace{d\big(q_\phi(\mathbf{c}_{j\to i}), q_\phi(\mathbf{c}_i)\big)}_{\text{锚点-正样本距离}} - \underbrace{d\big(q_\phi(\mathbf{c}_{j\to i}), q_\phi(\mathbf{c}_j)\big)}_{\text{锚点-负样本距离}} + m \Big]_+ Ltripleti=K11j=i[锚点-正样本距离 d(qϕ(cji),qϕ(ci))锚点-负样本距离 d(qϕ(cji),qϕ(cj))+m]+
其中:

  • d ( ⋅ , ⋅ ) d(\cdot,\cdot) d(,) 是散度度量,本文使用 KL 散度
  • m > 0 m > 0 m>0 是 margin,确保正样本比负样本至少近 m m m
  • [ ⋅ ] + = max ⁡ ( ⋅ , 0 ) [\cdot]_+ = \max(\cdot, 0) []+=max(,0) 是 ReLU

直观理解

  • 第一项:鼓励 c j → i \mathbf{c}_{j \to i} cji c i \mathbf{c}_i ci 推断出相似的任务身份;
  • 第二项:鼓励 c j → i \mathbf{c}_{j \to i} cji c j \mathbf{c}_j cj 推断出不同的任务身份。

关键点:由于 c j → i \mathbf{c}_{j \to i} cji c j \mathbf{c}_j cj 的状态-动作对完全相同,唯一的区别是奖励。因此,为了最小化该损失,任务推断模块必须考虑奖励信息 P ( Z ∣ S , A , R ) P(Z|S,A,R) P(ZS,A,R)

3. 总损失函数

最终损失为蒸馏损失与三元组损失的加权和:
L = L triplet + L Q + L G + L ξ \mathcal{L} = \mathcal{L}_{\text{triplet}} + \mathcal{L}_Q + \mathcal{L}_G + \mathcal{L}_\xi L=Ltriplet+LQ+LG+Lξ

其中,
L Q = 1 K ∑ i = 1 K E ( s , a ) , c i ∼ B i [ ( Q i ( s , a ) − Q D ( s , a , z i ) ) 2 + β KL ( q ϕ ( c i ) ∥ N ( 0 , 1 ) ) ] , z i ∼ q ϕ ( c i ) \mathcal{L}_Q = \frac{1}{K} \sum_{i=1}^K \mathbb{E}_{(s,a),\mathbf{c}_i \sim \mathcal{B}_i} \left[ (Q_i(s,a) - Q_D(s,a,z_i))^2 + \beta \text{KL}(q_\phi(\mathbf{c}_i) \| \mathcal{N}(0,1)) \right], \quad z_i \sim q_\phi(\mathbf{c}_i) LQ=K1i=1KE(s,a),ciBi[(Qi(s,a)QD(s,a,zi))2+βKL(qϕ(ci)N(0,1))],ziqϕ(ci)

L G = 1 K ∑ i = 1 K E s , c i ∼ B i ∥ G i ( s , ν ) − G D ( s , ν , z ˉ i ) ∥ 2 \mathcal{L}_G = \frac{1}{K} \sum_{i=1}^K \mathbb{E}_{s,\mathbf{c}_i \sim \mathcal{B}_i} \| G_i(s,\nu) - G_D(s,\nu,\bar{\mathbf{z}}_i) \|^2 LG=K1i=1KEs,ciBiGi(s,ν)GD(s,ν,zˉi)2

L ξ = 1 K ∑ i = 1 K E s , c i ∼ B i ν ∼ N ( 0 , 1 ) ∥ ξ i ( s , a ) − ξ D ( s , a , z ˉ i ) ∥ 2 , a = G i ( s , ν ) \mathcal{L}_\xi = \frac{1}{K} \sum_{i=1}^K \mathbb{E}_{\substack{s,\mathbf{c}_i \sim \mathcal{B}_i \\ \nu \sim \mathcal{N}(0,1)}} \| \xi_i(s,a) - \xi_D(s,a,\bar{\mathbf{z}}_i) \|^2, \quad a = G_i(s,\nu) Lξ=K1i=1KEs,ciBiνN(0,1)ξi(s,a)ξD(s,a,zˉi)2,a=Gi(s,ν)

记号 z ˉ i \bar{\mathbf{z}}_i zˉi 表示梯度停止(stop gradient),即 L G \mathcal{L}_G LG L ξ \mathcal{L}_\xi Lξ 不用于更新 q ϕ q_\phi qϕ,避免任务推断模块被生成质量差的动作误导。

4. 算法流程

MBML 分为两个阶段:

  1. 单任务策略训练:使用 BCQ 为每个训练任务训练独立的策略;
  2. 多任务策略蒸馏:结合三元组损失,将单任务策略蒸馏为多任务策略。

详细的伪代码见原文附录 E。

五、实验场景与结果

实验环境

我们在 5 个 MuJoCo 任务分布和 1 个修改后的 D4RL 任务上评估 MBML:

  • AntDir:蚂蚁朝目标方向奔跑
  • HumanoidDir-M:人形机器人朝目标方向奔跑(修改版本,避免平凡解)
  • AntGoal:蚂蚁导航至目标位置
  • UmazeGoal-M:在 U 型迷宫中导航至目标位置
  • HalfCheetahVel:猎豹维持目标速度
  • WalkerParam:通过随机物理参数改变转移函数
任务分布任务定义方式训练任务数测试任务数状态空间挑战点
AntDir目标奔跑方向(120°弧内)108proprioceptive state状态不含方向信息
HumanoidDir-M目标奔跑方向108同上奖励系数调整后任务差异显著
AntGoal目标位置(120°弧内)108同上需导航到不同位置
HalfCheetahVel目标速度108同上速度控制精度
WalkerParam物理参数(质量、摩擦等)308同上转移函数变化
UmazeGoal-M迷宫目标位置108位置+速度稀疏奖励场景

基线方法

  • PEARL(修改为批训练版本)
  • Contextual BCQ(在 BCQ 基础上增加任务推断模块)
  • MetaGenRL(基于 DDPG,在批设定中容易发散)

主要结果

  1. 在未见任务上的表现:MBML 在所有任务分布上均优于基线方法。

    • PEARL 在某些任务上表现尚可,但未针对 Batch RL 的离线特性优化;Contextual BCQ 稳定但收敛到次优解;MetaGenRL 快速发散。
  2. 消融实验

    • Full Model (MBML):完整方法
    • No Relabeling (NR):仅使用原始上下文,通过大批量采样构造难负样本(计算复杂度高)
    • No Triplet Loss (NT):仅将重标注数据加入输入,但无三元组损失
    • Neither:简单蒸馏 + 任务推断模块
    • Ground Truth (GT):使用真实奖励函数重标注(作为性能上限)

    结果

    • 在 5/6 个任务上,完整模型显著优于所有消融版本
    • NR 因计算效率低且难负样本质量差,性能下降
    • NT 虽利用奖励信息,但无显式约束,提升有限
    • Neither 完全失效,验证了三元组损失的必要性
    • GT 与 MBML 性能接近,说明学习的奖励函数足够准确
  3. 作为初始化的加速效果:将训练好的多任务策略用于初始化 SAC,在未见任务上训练时,收敛速度提升高达 80%

分析与讨论

  • 为什么三元组损失有效:它强制模型区分仅奖励不同的过渡数据,从而学习到正确的任务依赖 P ( Z ∣ S , A , R ) P(Z \mid S, A, R) P(ZS,A,R)
  • 计算效率:使用重新标记的三元组损失计算复杂度为 O ( K 2 ) O(K^2) O(K2),而传统硬负例挖掘需要 O ( K 2 N 2 ) O(K^2 N^2) O(K2N2)
  • 奖励预测的泛化性:即使状态-动作分布不重叠,学到的奖励函数也能在一定程度上泛化,足以支撑三元组损失的有效性。

六、总结与展望

本文提出了 MBML,一种通过三元组损失过渡重新标记增强任务推断能力的多任务批强化学习方法。本文的核心在于通过度量学习增强任务推断的鲁棒性,尤其是在数据分布不重叠的批强化学习设定中。该方法不仅提升了泛化性能,还提供了一种高效的策略初始化方案,为实际应用中的样本效率问题提供了有希望的解决方案。核心贡献在于:

  1. 指出了在多任务批 RL 中,由于数据集分布不重叠导致任务推断模块可能忽略奖励信号的问题;
  2. 提出了一种新颖的三元组损失设计,通过奖励重新标记构建硬负例,强制模型考虑奖励;
  3. 实验表明,MBML 在多个任务分布上优于基线方法,并且训练好的策略可作为优质初始化,大幅提升后续训练效率。

Limitation

  • 文章主要讨论了奖励函数差异的情况,其技术细节和公式也集中围绕奖励重标注展开。但对于转移函数(transition function)不同的情况,文章的讨论相对简略。

  • 在原文 Algorithm 3Appendix D 的重标注过程中,仅对奖励进行替换,而保留了原始的下一个状态

  • 文章明确承认这是一个局限)在 Discussion (Sec. 4) 的 Limitations 部分,作者写道:

    “We also assume the learnt reward function of one task can generalize to state-action pairs from the other tasks, even when their state-action visitation frequencies do not overlap significantly.”

    这表明方法的核心假设是奖励函数可跨任务泛化,但未对转移函数的可泛化性做出类似假设

  • 实验中做了转移函数差异任务)文章在 Sec. 5.1 提到:

    “We also consider the WalkerParam environment where random physical parameters parameterize the agent, inducing different transition functions in each task.”

    WalkerParam 在实验中确实被包含,且 MBML 表现良好(+50% 超过 Contextual BCQ)。这说明即使不直接处理转移函数差异,方法依然有效

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值