强化学习离线强化学习DecisionTransformer模型详解

Decision Transformer 原理详解


1. 背景与动机

传统强化学习的挑战

  1. 长期依赖建模困难:传统RL算法(如DQN、PPO)在需要长程决策的任务中表现受限。
  2. 样本效率低下:在线交互学习需要大量环境交互,数据利用率低。
  3. 离线学习局限:离线RL方法(如CQL、BCQ)需复杂约束设计,难以直接利用历史轨迹。

Transformer的启示

  • 序列建模优势:Transformer在NLP中展现了对长序列的建模能力,通过自注意力机制捕捉全局依赖。
  • 监督学习范式:将RL问题转化为序列预测任务,利用历史轨迹数据直接生成动作。

2. 核心思想

将强化学习任务建模为条件序列生成问题,使用Transformer模型以自回归方式生成动作序列,条件为历史状态、动作及目标回报(或奖励)。


3. 建模方式

输入表示

将轨迹数据表示为三元组序列,每个时间步包含:
Token t = ( s t , a t , R t ) \text{Token}_t = (s_t, a_t, R_t) Tokent=(st,at,Rt)

  • 状态 s t s_t st):当前环境观测。
  • 动作 a t a_t at):当前执行的动作。
  • 目标回报 R t R_t Rt):从当前时刻到轨迹结束的累计回报(或未来剩余回报)。

序列构建

  • 轨迹分段:对一条轨迹 τ = ( s 0 , a 0 , r 0 , s 1 , . . . , s T ) \tau = (s_0, a_0, r_0, s_1, ..., s_T) τ=(s0,a0,r0,s1,...,sT),计算每个时刻的剩余回报 R t = ∑ k = t T γ k − t r k R_t = \sum_{k=t}^T \gamma^{k-t} r_k Rt=k=tTγktrk
  • 输入序列:按时间步拼接为:
    [ R 0 , s 0 , a 0 , R 1 , s 1 , a 1 , . . . , R T , s T ] [R_0, s_0, a_0, R_1, s_1, a_1, ..., R_T, s_T] [R0,s0,a0,R1,s1,a1,...,RT,sT]

模型结构

在这里插入图片描述

  • Transformer Decoder:使用类似GPT的结构,以因果掩码确保预测仅依赖历史信息。
  • Token嵌入:将 R t , s t , a t R_t, s_t, a_t Rt,st,at 分别嵌入为向量并拼接。
  • 位置编码:添加时间步位置信息。

4. 损失函数

通过最大似然估计(MLE)优化动作预测的准确性:
L = − ∑ t = 0 T log ⁡ P ( a t ∣ R ≤ t , s ≤ t , a < t ) \mathcal{L} = -\sum_{t=0}^T \log P(a_t | R_{\leq t}, s_{\leq t}, a_{<t}) L=t=0TlogP(atRt,st,a<t)

  • 目标:给定历史状态、回报和动作,正确预测当前动作。
  • 实现细节:仅对动作部分的预测计算损失,状态和回报作为条件输入。

5. 训练与推理

在这里插入图片描述

如上图所示:固定图结构下寻找最短路径的示例(左图)被建模为强化学习问题。训练数据集由随机游走轨迹及其每个节点的剩余回报组成(中图)。在给定起始状态且要求每个节点生成最大可能回报的条件下,Decision Transformer能够序列化生成最优路径。

训练阶段

  • 输入:完整轨迹的 [ R t , s t , a t ] [R_t, s_t, a_t] [Rt,st,at] 序列。
  • 掩码机制:使用因果掩码防止未来信息泄漏。
  • 优化目标:最小化动作预测的交叉熵损失。

推理阶段

  • 自回归生成:以初始状态 s 0 s_0 s0 和目标回报 R 0 R_0 R0 为起点,逐步生成动作:
    1. 输入当前序列 [ R 0 , s 0 , a 0 , R 1 , s 1 ] [R_0, s_0, a_0, R_1, s_1] [R0,s0,a0,R1,s1]
    2. 预测下一动作 a 1 a_1 a1
    3. 执行 a 1 a_1 a1,获取新状态 s 2 s_2 s2,更新剩余回报 R 1 = R 0 − r 0 R_1 = R_0 - r_0 R1=R0r0
    4. 重复直到终止。

import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class DecisionTransformer(nn.Module):
    def __init__(self, state_dim, act_dim, hidden_dim=128, nhead=4, num_layers=3, max_timestep=1000):
        """
        Decision Transformer 模型
        
        参数:
            state_dim (int): 状态维度
            act_dim (int): 动作维度
            hidden_dim (int): 隐藏层维度 (默认: 128)
            nhead (int): Transformer头数 (默认: 4)
            num_layers (int): Transformer层数 (默认: 3)
            max_timestep (int): 最大时间步数 (用于位置编码,默认: 1000)
        """
        super().__init__()
        
        # 嵌入层定义
        self.embed_t = nn.Embedding(max_timestep, hidden_dim)  # 时间步位置嵌入
        self.embed_R = nn.Linear(1, hidden_dim)                # Return-to-go嵌入
        self.embed_s = nn.Linear(state_dim, hidden_dim)        # 状态嵌入
        self.embed_a = nn.Linear(act_dim, hidden_dim)          # 动作嵌入
        
        # Transformer编码器
        encoder_layers = TransformerEncoderLayer(hidden_dim, nhead, dim_feedforward=4*hidden_dim)
        self.transformer = TransformerEncoder(encoder_layers, num_layers)
        
        # 动作预测头
        self.pred_a = nn.Linear(hidden_dim, act_dim)           
        
        # 初始化参数
        self.hidden_dim = hidden_dim
        self.act_dim = act_dim

    def forward(self, R, s, a, timesteps):
        """
        前向传播
        
        参数:
            R (Tensor): Return-to-go序列 [batch_size, seq_len, 1]
            s (Tensor): 状态序列 [batch_size, seq_len, state_dim]
            a (Tensor): 动作序列 [batch_size, seq_len, act_dim]
            timesteps (Tensor): 时间步序列 [batch_size, seq_len]
            
        返回:
            Tensor: 预测动作序列 [batch_size, seq_len, act_dim]
        """
        batch_size, seq_len = R.shape[:2]
        
        # 1. 计算各元素的嵌入
        pos_emb = self.embed_t(timesteps)  # [batch_size, seq_len, hidden_dim]
        
        # 计算各模态嵌入并添加位置信息
        R_emb = self.embed_R(R) + pos_emb  # [batch_size, seq_len, hidden_dim]
        s_emb = self.embed_s(s) + pos_emb  # [batch_size, seq_len, hidden_dim]
        a_emb = self.embed_a(a) + pos_emb  # [batch_size, seq_len, hidden_dim]
        
        # 2. 按(R_1, s_1, a_1, ...)顺序交错堆叠
        input_embeds = torch.stack([R_emb, s_emb, a_emb], dim=2)  # [batch_size, seq_len, 3, hidden_dim]
        input_embeds = input_embeds.view(batch_size, 3*seq_len, self.hidden_dim)  # [batch_size, 3*seq_len, hidden_dim]
        
        # 3. 添加序列维度并应用Transformer
        src = input_embeds.permute(1, 0, 2)  # [seq_len*3, batch_size, hidden_dim]
        
        # 生成因果掩码 (仅允许关注历史信息)
        mask = nn.Transformer.generate_square_subsequent_mask(src.size(0)).to(src.device)
        
        # Transformer编码
        hidden_states = self.transformer(src, mask)  # [seq_len*3, batch_size, hidden_dim]
        hidden_states = hidden_states.permute(1, 0, 2)  # [batch_size, seq_len*3, hidden_dim]
        
        # 4. 提取动作对应的隐藏状态
        a_hidden = hidden_states[:, 2::3]  # 每3个元素取第3个 (对应a的位置)
        
        # 5. 预测动作
        return self.pred_a(a_hidden)  # [batch_size, seq_len, act_dim]

# 示例用法
if __name__ == "__main__":
    # 超参数定义
    state_dim = 10   # 状态维度
    act_dim = 4      # 动作维度
    batch_size = 32  # 批大小
    seq_len = 20     # 序列长度
    
    # 初始化模型
    model = DecisionTransformer(state_dim, act_dim)
    
    # 生成随机输入数据
    R = torch.randn(batch_size, seq_len, 1)       # Return-to-go
    s = torch.randn(batch_size, seq_len, state_dim)  # 状态
    a = torch.randn(batch_size, seq_len, act_dim)    # 动作
    t = torch.randint(0, 1000, (batch_size, seq_len))# 时间步
    
    # 前向传播
    a_pred = model(R, s, a, t)
    print("预测动作形状:", a_pred.shape)  # 应为 [32, 20, 4]
    
    # 训练步骤示例
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = nn.MSELoss()
    
    # 计算损失
    loss = loss_fn(a_pred, a)
    print("初始损失:", loss.item())
    
    # 反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # 评估循环示例 (需要与环境交互)
    def evaluate(model, env, target_return, K=20):
        model.eval()
        state = env.reset()
        states = [state]
        returns = [target_return]
        actions = []
        timesteps = [0]
        
        done = False
        while not done:
            # 准备输入序列 (保持最近K步)
            R_in = torch.FloatTensor(returns[-K:]).unsqueeze(0)  # [1, min(K,t), 1]
            s_in = torch.FloatTensor(states[-K:]).unsqueeze(0)   # [1, min(K,t), state_dim]
            a_in = torch.FloatTensor(actions[-K:]).unsqueeze(0) if len(actions)>0 else torch.empty(1,0,act_dim)
            t_in = torch.LongTensor(timesteps[-K:]).unsqueeze(0) # [1, min(K,t)]
            
            # 预测动作
            with torch.no_grad():
                a_pred = model(R_in, s_in, a_in, t_in)  # [1, seq_len, act_dim]
                next_action = a_pred[0, -1]  # 取最新预测动作
            
            # 与环境交互
            next_state, reward, done, _ = env.step(next_action.numpy())
            
            # 更新序列
            returns.append(returns[-1] - reward)  # 更新剩余回报
            states.append(next_state)
            actions.append(next_action)
            timesteps.append(len(returns))
        
        return sum(returns)  # 返回累计奖励
  • 代码关键点

    • 嵌入层结构:

      时间步使用可学习的Embedding层

      各模态(R/s/a)使用独立线性变换后加入位置信息

    • 序列构建:

      输入序列按(R_t, s_t, a_t)的顺序交错排列

      使用3倍于原序列长度的维度来处理三种不同模态

    • 注意力机制:

      使用标准Transformer编码器

      通过因果掩码强制实现自回归特性

    • 训练优化:

      使用简单的MSE损失进行动作预测

      支持批量训练和序列截断

    • 推理策略:

      采用自回归生成方式

      维护滑动窗口保持上下文长度

      逐步更新returns-to-go

  • 实际使用时需要注意

    • 输入需要标准化处理

    • 训练数据需要包含高质量的轨迹

    • 目标returns-to-go的设置需要与任务奖励尺度匹配

    • 上下文长度K的选择会影响性能


6. 与传统强化学习的区别

特性Decision Transformer传统RL算法
核心机制序列建模,监督学习直接预测动作基于Bellman方程,动态规划或策略梯度优化
问题建模序列生成(监督学习)动态规划/策略优化(交互学习)
数据利用直接利用离线轨迹,无需交互依赖在线交互或重要性采样
回报处理目标回报作为条件输入通过价值函数隐式建模
策略表示显式生成动作序列策略网络或值函数近似
长期依赖处理自注意力机制捕捉全局依赖依赖循环网络或折扣因子
探索机制依赖数据分布,无显式探索ε-greedy、随机策略等显式探索

7. 优势与局限性

优势

  1. 离线学习友好:直接利用历史数据,无需环境交互。
  2. 长程决策能力:Transformer全局注意力机制有效捕捉长期依赖。
  3. 简化流程:将RL转化为监督学习,避免复杂值函数优化。

局限性

  1. 依赖高质量数据:数据分布偏差会导致策略次优。
  2. 目标回报敏感性:需合理设置目标回报,否则可能生成不合理动作。
  3. 实时性不足:自回归生成速度较慢,难以应对高实时性任务。

8. 实例说明

GridWorld导航任务

  • 目标:智能体从起点到达目标点,最大化累计奖励。
  • Decision Transformer输入
    • 状态 s t s_t st:当前位置坐标。
    • 动作 a t a_t at:移动方向(上/下/左/右)。
    • 目标回报 R t R_t Rt:剩余曼哈顿距离的负值。
  • 输出:预测每一步动作,引导智能体走向目标。

9. Decision Transformer Return-to-go思想扩展

Decision Transformer(DT)中的Return-to-Go(RTG)与自然语言处理(NLP)中的Prompt存在一定的相似性,但它们的核心机制和设计目标有所不同。以下是详细解释:


9.1. RTG与Prompt的类比

相似性
  • 引导生成过程
  • Prompt在NLP中用于指定生成任务的目标(例如“翻译以下句子:…”),直接引导模型生成符合预期的内容。
  • RTG在DT中表示当前轨迹的剩余目标回报(即从当前时刻到轨迹结束的累积奖励),它动态地引导模型生成能够达成该回报的动作序列。
  • 条件输入
    两者都作为模型的输入条件,直接影响模型的输出。RTG可以被视为一种“强化学习领域的Prompt”,因为它通过数值化的目标间接“提示”模型应如何行动。
关键区别
  • 动态性
    RTG在轨迹的每一步会更新(逐渐减少),而传统Prompt通常是静态的文本前缀。
  • 数值化目标
    RTG是数值型的条件信号,直接与强化学习的奖励机制挂钩,而Prompt通常是符号化的语义信息。

9.2. RTG如何影响轨迹生成

训练阶段
  • 监督信号
    在训练时,RTG是基于真实轨迹的剩余回报计算的(例如专家轨迹的累积奖励)。模型通过历史状态、动作和RTG序列,学习预测下一步动作。
  • 目标条件化
    RTG的作用类似于目标条件策略(Goal-Conditioned Policy),使模型能够根据不同的目标回报生成对应的动作。
推断阶段
  • 目标设定
    在推断时,可以通过手动设置RTG的初始值(例如设置更高的目标回报)来尝试引导模型生成更优的轨迹。
  • 诱导机制
    如果模型在训练时见过类似的高回报轨迹,理论上可以通过设置更高的RTG值,诱导模型生成更优动作。这与Prompt工程中通过优化提示词提升生成质量的思想一致。

9.3. 通过优秀RTG诱导更好轨迹的可行性

成功条件
  • 训练数据的覆盖性
    如果模型在训练时见过高回报轨迹对应的RTG值,且能够泛化到更高的RTG值,则可能生成更优轨迹。
  • 模型的泛化能力
    Transformer的注意力机制使其对序列模式具有强捕捉能力。如果RTG与动作的关联模式被正确学习,模型可能对未见过的RTG值产生合理响应。
潜在问题
  • 不匹配风险
    若设置的RTG值远超出训练数据的分布(例如设置一个不切实际的高回报目标),模型可能生成不合理动作(类似NLP中的“幻觉”)。(扩展:HAC通过subgoal testing trasition策略(Hindsight数据增强技巧)对底层策略能力不匹配的subgoal进行惩罚。)
  • 局部最优陷阱
    DT基于历史数据模仿策略,可能无法像在线强化学习算法(如PPO)那样主动探索新策略,导致对RTG的响应受限。

9.4. 实验验证与改进方法

实验观察

在这里插入图片描述

上图展示了在指定目标(期望)回报条件下,由决策变换器(Decision Transformer)累积的采样(评估)回报。使用比数据集中最大单episode回报更高的回报来提示Decision Transformer,结果表明该模型有时具备外推能力。

  • 在原始DT论文中,通过替换RTG为专家轨迹的初始值,可以生成接近专家水平的轨迹。
  • 若逐步提高RTG(例如从低到高逐步微调),模型可能逐步适应更高的目标(类似于课程学习)。

其中:

  • 横轴代表目标回报,return-to-go,作为条件输入到transformer
  • 纵轴代表按照当前return-to-go作为条件输入执行策略得到的实际回报
  • 黄色竖线代表离线训练数据集中最大单episode回报作为return-to-go时的表现,代表 Behavior Cloning算法能力的天花板。
改进方向
  • RTG规范化
    在训练时对RTG进行标准化或分桶处理,增强模型对不同目标范围的适应能力。
  • 多目标训练
    引入多任务学习,使模型同时学习不同RTG目标下的策略。
  • 在线微调
    结合在线强化学习,在推断时根据实际回报动态调整RTG,避免目标与能力不匹配。

9.5. 总结

  • RTG的作用
    它既是目标条件信号,也是一种隐式的Prompt,通过数值化的剩余回报引导模型生成动作。
  • 诱导策略的有效性
    合理设置RTG(如在训练数据分布范围内逐步提高)可以生成更优轨迹,但需注意模型的泛化边界。
  • 与Prompt的区别
    RTG的动态更新和数值化特性使其更适用于强化学习的时序决策场景,而Prompt更偏向符号化语义引导。

通过以上机制,Decision Transformer在某种程度上实现了“通过目标设定(RTG)诱导策略生成”,这一思想与Prompt工程有异曲同工之妙,但需结合强化学习的特性进行针对性设计。

10. 总结

Decision Transformer通过将强化学习问题重新建模为条件序列生成任务,利用Transformer的全局建模能力和监督学习范式,显著简化了离线策略优化过程。其在长程决策和离线场景中表现突出,但对数据质量和回报设计高度敏感,是当前离线RL领域的重要突破方向。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

贝塔西塔

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值