都在落地端到端!手撕代码,今天一起来梳理下UniAD的实现

点击下方卡片,关注“自动驾驶之心”公众号

戳我-> 领取自动驾驶近15个方向学习路线

今天自动驾驶之心为大家分享UniAD的核心代码解读,助力端到端落地!如果您有相关工作需要分享,请在文末联系我们!

自动驾驶课程学习与技术交流群事宜,也欢迎添加小助理微信AIDriver004做进一步咨询

>>点击进入→自动驾驶之心端到端自动驾驶技术交流群

编辑 | 自动驾驶之心

写在前面

UniAD是围绕查询设计的,它的感知、预测、规划任务都使用交叉注意力来将前置任务的查询转换为当前任务的查询。每个任务的查询以及这些查询的组合,都用长度为256的向量表征。最初的表征(BEV元素)通过多个Transformer Decoder,每个Decoder的输出有不同任务的监督数据,这些任务引导了中间表示(dim=256的向量)的演化。

c06678fa9d0f7fb871f26394f5c99da2.jpeg

在这个过程中,UniAD对表征做了很多组合和变换,引入了冗余的参数。(与之相比,英伟达的ParaDrive则直接从BEV并行训练多任务,取得了更好的效果。)中间表征在演化过程中有多个名字,如下表中形状中有x256的变量,作为每个模块的输入和输出。在开始关注每个模块的代码前,需要牢记下表中每个查询变量的含义和形状。

如果大家对端到端自动驾驶技术栈还不是很熟悉,可以翻一翻我们前面的端到端技术栈汇总!

对于感知部分,由于已经有很多文章对目标检测、跟踪和建图的论文和代码做过解读,因此我们不再重复。在这篇文章中,我们只对运动预测、占用预测和规划器进行解读。

90d27cfc1bf5df98406db07aedfbc6b3.jpeg

运动预测模块

0b6754e73cfbec841ae8c73ddc277e87.jpeg

运动预测模块是由3层Transformer构成的解码器,每个Transformer层有3个并行的交叉注意力模块,分别为对象-对象交互、对象-地图交互、以及对象-目标交互注意力模块。它们分别将上游跟踪模块的查询、建图模块的查询和鸟瞰特征作为键值,

323f6461c1a2dd97cb6bc6dc215b68ec.png

本模块的查询经过标准的多头自注意力和交叉注意力(键值为),输出临时查询,和它经过可变形注意力(键值为,参考点为)输出的临时查询串联拼接,再压缩为256维的,作为下一层的输入。的构造方式为,

d66ca4b9f42c766e238a05d34c182e7b.png

其中,是自车的锚点轨迹,是转换为场景坐标系的,它们是Ground Truth中统计频率最高的6条轨迹。是所有对象的起点,是所有对象的终点(初始时被设置为)。是这一层的输出,在下一层输入前和相加。从形式上看,它类似于一个残差,但比残差的构造复杂一点。

这个模块的结构可以在配置文件base_e2e.py中查看。我们忽略了参数部分,只关注模型的前向传播链。我们把写在模块代码的__init__中而不在配置文件中的部分也填充进去,构成完整的前向传播链:

# 配置文件中的前向传播链

motion_head=dict(
        type='MotionHead',
        
        transformerlayers=dict(
            type='MotionTransformerDecoder',

            transformerlayers=dict(
                type='MotionTransformerAttentionLayer',

                attn_cfgs=[
                    dict(
                        type='MotionDeformableAttention',
                    ),
                ],

                operation_order=('cross_attn', 'norm', 'ffn', 'norm')),
        ),
    ),

# 完整的前向传播链

motion_head=dict(
        type='MotionHead',
        
        transformerlayers=dict(
            // 解码器
            type='MotionTransformerDecoder',

            // 对象-目标交互
            intention_interaction_layers = IntentionInteraction()
            // 对象-对象交互
            track_agent_interaction_layers = nn.ModuleList(
            [TrackAgentInteraction() for i in range(self.num_layers)])
            // 对象-地图交互
            map_interaction_layers = nn.ModuleList(
            [MapInteraction() for i in range(self.num_layers)])
            // 对象-目标交互
            bev_interaction_layers = nn.ModuleList(
            [build_transformer_layer(transformerlayers) for i in range(self.num_layers)])

            // 一些对查询进行转换的MLP层
            static_dynamic_fuser = nn.Sequential(
                nn.Linear(self.embed_dims*2, self.embed_dims*2),
                nn.ReLU(),
                nn.Linear(self.embed_dims*2, self.embed_dims),
            )
            dynamic_embed_fuser = nn.Sequential(
                nn.Linear(self.embed_dims*3, self.embed_dims*2),
                nn.ReLU(),
                nn.Linear(self.embed_dims*2, self.embed_dims),
            )
            in_query_fuser = nn.Sequential(
                nn.Linear(self.embed_dims*2, self.embed_dims*2),
                nn.ReLU(),
                nn.Linear(self.embed_dims*2, self.embed_dims),
            )
            out_query_fuser = nn.Sequential(
                nn.Linear(self.embed_dims*4, self.embed_dims*2),
                nn.ReLU(),
                nn.Linear(self.embed_dims*2, self.embed_dims),
            )

            
            transformerlayers=dict(
                // 对象-目标交互
                type='MotionTransformerAttentionLayer',

                attn_cfgs=[
                    dict(
                        type='MotionDeformableAttention',
                    ),
                ],

                operation_order=('cross_attn', 'norm', 'ffn', 'norm')
            ),
        ),
    ),

我们按照UniAD、MotionHead、MotionTransformerDecoder的顺序看前向传播链的forward函数。其中:

  1. UniAD的forward只是简单地串联起了五个模块

  2. MotionHead相当于MotionTransformerDecoder的预处理和后处理,它做了锚点和轨迹首尾点的嵌入,作为MotionTransformerDecoder的输入,之后又将MotionTransformerDecoder的输出查询转换为轨迹点列

  3. MotionTransformerDecoder是最核心的部分,它将锚点和轨迹的嵌入相加构造输入查询,通过三个交叉注意力模块吸收上游模块的查询的信息,再将它们的输出串联并压缩,得到输出查询。

我们对代码中的关键部分进行注释,并省略相对次要的部分。

MotionHead:

这段代码主要实现了以下功能:

  1. 构造并归一化代理级别和场景级别的锚点。

  2. 使用嵌入层将锚点的位置信息转换为嵌入向量。

  3. 通过MotionFormer模型进行前向传播,获取中间状态和参考轨迹。

  4. 计算每个级别的轨迹分数和轨迹,并应用双变量高斯激活。

  5. 构造输出字典,包含轨迹分数、预测轨迹、轨迹查询的有效性掩码等信息,并返回

class MotionHead(BaseMotionHead):

    def forward(self, 
                bev_embed, 
                track_query, 
                lane_query, 
                lane_query_pos, 
                track_bbox_results):
        """
        该函数执行模型的前向传播,用于基于鸟瞰图(BEV)嵌入、轨迹查询、车道查询和轨迹边界框结果进行运动预测。

        参数:

        bev_embed (torch.Tensor):形状为 (h*w, B, D) 的张量,表示鸟瞰图嵌入。
        track_query (torch.Tensor):形状为 (B, num_dec, A_track, D) 的张量,表示轨迹查询。
        lane_query (torch.Tensor):形状为 (N, M_thing, D) 的张量,表示车道查询。
        lane_query_pos (torch.Tensor):形状为 (N, M_thing, D) 的张量,表示车道查询的位置。
        track_bbox_results (List[torch.Tensor]):包含批次中每个图像的跟踪边界框结果的张量列表。
        返回值:

        dict:包含以下键和值的字典:
        'all_traj_scores':形状为 (num_levels, B, A_track, num_points) 的张量,包含每个级别的轨迹分数。
        'all_traj_preds':形状为 (num_levels, B, A_track, num_points, num_future_steps, 2) 的张量,包含每个级别的预测轨迹。
        'valid_traj_masks':形状为 (B, A_track) 的张量,指示轨迹掩码的有效性。
        'traj_query':包含轨迹查询中间状态的张量。
        'track_query':包含输入轨迹查询的张量。s
        'track_query_pos':包含轨迹查询位置嵌入的张量。
        """
        
        ...
        

        # 构造代理级别/场景级别的查询位置嵌入  
        # (num_groups, num_anchor, 12, 2)  
        # 以融入不同组和坐标的信息,并嵌入方向和位置信息 
        agent_level_anchors = self.kmeans_anchors.to(dtype).to(device).view(num_groups, self.num_anchor, self.predict_steps, 2).detach()
        scene_level_ego_anchors = anchor_coordinate_transform(agent_level_anchors, track_bbox_results, with_translation_transform=True)  # B, A, G, P ,12 ,2
        scene_level_offset_anchors = anchor_coordinate_transform(agent_level_anchors, track_bbox_results, with_translation_transform=False)  # B, A, G, P ,12 ,2

        # 对锚点进行归一化
        agent_level_norm = norm_points(agent_level_anchors, self.pc_range)
        scene_level_ego_norm = norm_points(scene_level_ego_anchors, self.pc_range)
        scene_level_offset_norm = norm_points(scene_level_offset_anchors, self.pc_range)

        # 仅使用锚点的最后一个点
        agent_level_embedding = self.agent_level_embedding_layer(
            pos2posemb2d(agent_level_norm[..., -1, :]))  # G, P, D
        scene_level_ego_embedding = self.scene_level_ego_embedding_layer(
            pos2posemb2d(scene_level_ego_norm[..., -1, :]))  # B, A, G, P , D
        scene_level_offset_embedding = self.scene_level_offset_embedding_layer(
            pos2posemb2d(scene_level_offset_norm[..., -1, :]))  # B, A, G, P , D
        
        ...

        outputs_traj_scores = []
        outputs_trajs = []

        # 通过MotionFormer模型进行前向传播
        # 输入各种查询、位置、边界框结果、BEV嵌入、初始参考轨迹等  
        # 以及锚点嵌入和锚点位置嵌入层  
        inter_states, inter_references = self.motionformer(
            track_query,  # B, A_track, D
            lane_query,  # B, M, D
            track_query_pos=track_query_pos,
            lane_query_pos=lane_query_pos,
            track_bbox_results=track_bbox_results,
            bev_embed=bev_embed,
            reference_trajs=init_reference,
            traj_reg_branches=self.traj_reg_branches,
            traj_cls_branches=self.traj_cls_branches,
            # anchor embeddings 
            agent_level_embedding=agent_level_embedding,
            scene_level_ego_embedding=scene_level_ego_embedding,
            scene_level_offset_embedding=scene_level_offset_embedding,
            learnable_embed=learnable_embed,
            # anchor positional embeddings layers
            agent_level_embedding_layer=self.agent_level_embedding_layer,
            scene_level_ego_embedding_layer=self.scene_level_ego_embedding_layer,
            scene_level_offset_embedding_layer=self.scene_level_offset_embedding_layer,
            spatial_shapes=torch.tensor(
                [[self.bev_h, self.bev_w]], device=device),
            level_start_index=torch.tensor([0], device=device))

        # 遍历每个级别,计算轨迹分数和轨迹
        for lvl in range(inter_states.shape[0]):
            outputs_class = self.traj_cls_branches[lvl](inter_states[lvl])
            tmp = self.traj_reg_branches[lvl](inter_states[lvl])
            tmp = self.unflatten_traj(tmp)
            
            # 使用累积和技巧来获取轨迹
            tmp[..., :2] = torch.cumsum(tmp[..., :2], dim=3)

            outputs_class = self.log_softmax(outputs_class.squeeze(3))
            outputs_traj_scores.append(outputs_class)

            # 对每个批次应用双变量高斯激活
            for bs in range(tmp.shape[0]):
                tmp[bs] = bivariate_gaussian_activation(tmp[bs])
            outputs_trajs.append(tmp)

        # 堆叠并输出轨迹分数和轨迹
        outputs_traj_scores = torch.stack(outputs_traj_scores)
        outputs_trajs = torch.stack(outputs_trajs)

        # 获取轨迹查询的有效性掩码
        B, A_track, D = track_query.shape
        valid_traj_masks = track_query.new_ones((B, A_track)) > 0

        # 构造输出字典
        outs = {
            'all_traj_scores': outputs_traj_scores,
            'all_traj_preds': outputs_trajs,
            'valid_traj_masks': valid_traj_masks,
            'traj_query': inter_states,
            'track_query': track_query,
            'track_query_pos': track_query_pos,
        }

        return outs

MotionTransformerDecoder:

这段代码融合了静态意图、动态意图、代理之间的交互、代理与地图的交互以及代理与目标的交互,最终生成了融合了多种信息的查询嵌入query_embed。这里意图(intention)代表目标位置。

class MotionTransformerDecoder(BaseModule):

    def forward(self,
                track_query,
                lane_query,
                track_query_pos=None,
                lane_query_pos=None,
                track_bbox_results=None,
                bev_embed=None,
                reference_trajs=None,
                traj_reg_branches=None,
                agent_level_embedding=None,
                scene_level_ego_embedding=None,
                scene_level_offset_embedding=None,
                learnable_embed=None,
                agent_level_embedding_layer=None,
                scene_level_ego_embedding_layer=None,
                scene_level_offset_embedding_layer=None,
                **kwargs):
        """Forward function for `MotionTransformerDecoder`.
        Args:
            agent_query (B, A, D):代理查询,其中 B 表示批次大小,A 表示代理(agent)的数量,D 表示特征维度。
            map_query (B, M, D):地图查询,其中 M 表示地图中的对象数量。
            map_query_pos (B, G, D):地图查询位置。
            static_intention_embed (B, A, P, D):静态意图嵌入,其中 P 表示意图的数量。代表每个代理的静态或固定的意图。
            offset_query_embed (B, A, P, D):偏移查询嵌入,与意图的偏移或变化有关。
            global_intention_embed (B, A, P, D):全局意图嵌入,代表每个代理的全局或整体意图。
            learnable_intention_embed (B, A, P, D):可学习的意图嵌入,是模型在训练过程中学习的意图表示。
            det_query_pos (B, A, D):检测查询位置,代表与检测任务相关的位置信息。
        Returns:
            None
        """
        intermediate = []    # 用于存储中间输出的列表  
        intermediate_reference_trajs = []    # 用于存储中间参考轨迹的列表

        # 对输入进行广播和扩展,以匹配所需的形状
        B, _, P, D = agent_level_embedding.shape
        track_query_bc = track_query.unsqueeze(2).expand(-1, -1, P, -1)  # (B, A, P, D)
        track_query_pos_bc = track_query_pos.unsqueeze(2).expand(-1, -1, P, -1)  # (B, A, P, D)

        # 计算静态意图嵌入,它在所有层中都是不变的
        agent_level_embedding = self.intention_interaction_layers(agent_level_embedding)
        static_intention_embed = agent_level_embedding + scene_level_offset_embedding + learnable_embed
        reference_trajs_input = reference_trajs.unsqueeze(4).detach()

        # 初始化查询嵌入,其形状与静态意图嵌入相同  
        query_embed = torch.zeros_like(static_intention_embed)
        for lid in range(self.num_layers):
            # 融合动态意图嵌入  
            # 动态意图嵌入是前一层的输出,初始化为锚点嵌入(anchor embedding)
            dynamic_query_embed = self.dynamic_embed_fuser(torch.cat(
                [agent_level_embedding, scene_level_offset_embedding, scene_level_ego_embedding], dim=-1))
            
            # 融合静态和动态意图嵌入  
            query_embed_intention = self.static_dynamic_fuser(torch.cat(
                [static_intention_embed, dynamic_query_embed], dim=-1))  # (B, A, P, D)
            
            # 将意图嵌入与查询嵌入融合
            query_embed = self.in_query_fuser(torch.cat([query_embed, query_embed_intention], dim=-1))
            
            # 代理之间的交互
            track_query_embed = self.track_agent_interaction_layers[lid](
                query_embed, track_query, query_pos=track_query_pos_bc, key_pos=track_query_pos)
            
            # 代理与地图之间的交互
            map_query_embed = self.map_interaction_layers[lid](
                query_embed, lane_query, query_pos=track_query_pos_bc, key_pos=lane_query_pos)
            
            # 代理与目标(BEV,即鸟瞰图)之间的交互,使用可变形Transformer实现  
            # implemented with deformable transformer
            bev_query_embed = self.bev_interaction_layers[lid](
                query_embed,
                value=bev_embed,
                query_pos=track_query_pos_bc,
                bbox_results=track_bbox_results,
                reference_trajs=reference_trajs_input,
                **kwargs)
            
            # 融合来自不同交互层的嵌入
            query_embed = [track_query_embed, map_query_embed, bev_query_embed, track_query_bc+track_query_pos_bc]
            query_embed = torch.cat(query_embed, dim=-1)
            query_embed = self.out_query_fuser(query_embed)
            ...

占用预测:

5a12b80dc8851480ff4b4bc2e1fde8a7.jpeg

占用预测模块由5层Transformer构成,每层有一个自注意力和一个交叉注意力模块:

692d7cd023dc29402070c4150e741284.png

自注意力层的输入是BEV特征的1/8下采样,是由跟踪和运动预测模块的输出查询串联后压缩构造,由和相乘得到。

占用预测模块的输出由解码后的压缩BEV和由掩码生成的占用特征相乘而得:

这部分的代码十分简单,因为DetrTransformerDecoder和DetrTransformerDecoderLayer是mmdet3d实现的,这里只是简单调用这个解码器和解码器层。这个模块的配置和前向传播如下,它主要实现了占用预测的任务,通过Transformer解码器对输入的特征图进行处理,生成未来的状态、掩码预测和占用逻辑。

occ_head=dict(
        type='OccHead',

        # Transformer
        transformer_decoder=dict(
            type='DetrTransformerDecoder',
            
            num_layers=5,
            transformerlayers=dict(
                type='DetrTransformerDecoderLayer',
                attn_cfgs=dict(
                    type='MultiheadAttention',
                    ),

                operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                                 'ffn', 'norm')),
            ),

    ),



class OccHead(BaseModule):
    def forward(self, x, ins_query):
        # 重新排列输入特征图 
        base_state = rearrange(x, '(h w) b d -> b d h w', h=self.bev_size[0])

        # 对特征图进行采样 
        base_state = self.bev_sampler(base_state)
        # 对特征图进行轻量级投影
        base_state = self.bev_light_proj(base_state)
        # 对特征图进行下采样
        base_state = self.base_downscale(base_state)
        base_ins_query = ins_query

        # 初始化最后的状态和查询
        last_state = base_state
        last_ins_query = base_ins_query
        future_states = []    # 存储未来的状态
        mask_preds = []        # 存储掩码预测
        temporal_query = []    # 存储时间查询
        temporal_embed_for_mask_attn = []    # 存储用于掩码注意力的时间嵌入

        # 计算每个块的Transformer层数
        n_trans_layer_each_block = self.num_trans_layers // self.n_future_blocks
        assert n_trans_layer_each_block >= 1
        
        # 遍历未来的块
        for i in range(self.n_future_blocks):
            # 下采样
            cur_state = self.downscale_convs[i](last_state)  # /4 -> /8

            # 注意力
            # 时间感知的ins_query
            cur_ins_query = self.temporal_mlps[i](last_ins_query)  # [b, q, d]
            temporal_query.append(cur_ins_query)

            # 生成注意力掩码 
            attn_mask, mask_pred, cur_ins_emb_for_mask_attn = self.get_attn_mask(cur_state, cur_ins_query)
            attn_masks = [None, attn_mask] 

            mask_preds.append(mask_pred)  # /1
            temporal_embed_for_mask_attn.append(cur_ins_emb_for_mask_attn)

            # 重新排列状态和查询
            cur_state = rearrange(cur_state, 'b c h w -> (h w) b c')
            cur_ins_query = rearrange(cur_ins_query, 'b q c -> q b c')

            # 遍历Transformer层
            for j in range(n_trans_layer_each_block):
                trans_layer_ind = i * n_trans_layer_each_block + j
                trans_layer = self.transformer_decoder.layers[trans_layer_ind]
                cur_state = trans_layer(
                    query=cur_state,  # [h'*w', b, c]
                    key=cur_ins_query,  # [nq, b, c]
                    value=cur_ins_query,  # [nq, b, c]
                    query_pos=None,  
                    key_pos=None,
                    attn_masks=attn_masks,
                    query_key_padding_mask=None,
                    key_padding_mask=None
                )  # out size: [h'*w', b, c]

            # 重新排列状态
            cur_state = rearrange(cur_state, '(h w) b c -> b c h w', h=self.bev_size[0]//8)
            
            # 上采样到/4
            cur_state = self.upsample_adds[i](cur_state, last_state)

            # 输出
            future_states.append(cur_state)  # [b, d, h/4, w/4]
            last_state = cur_state

        # 堆叠未来的状态、时间查询、掩码预测和查询嵌入
        future_states = torch.stack(future_states, dim=1)  # [b, t, d, h/4, w/4]
        temporal_query = torch.stack(temporal_query, dim=1)  # [b, t, q, d]
        mask_preds = torch.stack(mask_preds, dim=2)  # [b, q, t, h, w]
        ins_query = torch.stack(temporal_embed_for_mask_attn, dim=1)  # [b, t, q, d]

        # 将未来状态解码到更大的分辨率
        future_states = self.dense_decoder(future_states)
        ins_occ_query = self.query_to_occ_feat(ins_query)    # [b, t, q, query_out_dim]
        
        # 生成最终输出
        ins_occ_logits = torch.einsum("btqc,btchw->bqthw", ins_occ_query, future_states)
        
        return mask_preds, ins_occ_logits

规划模块

43884dad5b3af0ff17980cfa4d536972.jpeg

规划模块是3层解码器,在图中由下到上依次为:

  1. 构造查询Q:由跟踪和运动预测模块的输出查询和导航信息的嵌入相加,经过MLP,选择概率最大的单个轨迹(从6个轨迹中),最后加上位置编码

  2. 通过N层解码器融合BEV信息,得到轨迹

  3. 使用占用预测做碰撞优化

代码也比较简单,这个模块的配置和前向传播如下。它实现了基于BEV特征嵌入、占用掩码、驾驶命令等输入,生成SDC(自动驾驶车辆)轨迹的过程。

planning_head=dict(
        type='PlanningHeadSingleMode',
        embed_dims=256,

        planning_steps=planning_steps,
        loss_planning=dict(type='PlanningLoss'),
        loss_collision=[dict(type='CollisionLoss', delta=0.0, weight=2.5),
                        dict(type='CollisionLoss', delta=0.5, weight=1.0),
                        dict(type='CollisionLoss', delta=1.0, weight=0.25)],
        use_col_optim=use_col_optim,
        planning_eval=True,
        with_adapter=True,
    ),


class PlanningHeadSingleMode(nn.Module):  
    def forward(self,   
                bev_embed,  # BEV(鸟瞰图)特征嵌入  
                occ_mask,   # 占用实例掩码  
                bev_pos,    # BEV位置  
                sdc_traj_query, # SDC轨迹查询  
                sdc_track_query, # SDC轨迹追踪查询  
                command):   # 驾驶命令  
        """  
        前向传播过程。  
  
        参数:  
            bev_embed (torch.Tensor): 鸟瞰图特征嵌入。  
            occ_mask (torch.Tensor): 占用实例掩码。  
            bev_pos (torch.Tensor): BEV位置。  
            sdc_traj_query (torch.Tensor): SDC轨迹查询。  
            sdc_track_query (torch.Tensor): SDC轨迹追踪查询。  
            command (int): 驾驶命令。  
  
        返回:  
            dict: 包含SDC轨迹和所有SDC轨迹的字典。  
        """  
  
        ...  
  
        # 根据驾驶命令获取导航嵌入  
        navi_embed = self.navi_embed.weight[command]  
        navi_embed = navi_embed[None].expand(-1, P, -1)  
        # 融合SDC轨迹查询、SDC轨迹追踪查询和导航嵌入  
        plan_query = torch.cat([sdc_traj_query, sdc_track_query, navi_embed], dim=-1)  
  
        # 使用多层感知机(MLP)融合查询,并取最大值  
        plan_query = self.mlp_fuser(plan_query).max(1, keepdim=True)[0]  
        # 重排plan_query的形状  
        plan_query = rearrange(plan_query, 'b p c -> p b c')  
          
        # 重排bev_pos的形状  
        bev_pos = rearrange(bev_pos, 'b c h w -> (h w) b c')  
        bev_feat = bev_embed + bev_pos  
          
        # 插件适配器  
        if self.with_adapter:  
            bev_feat = rearrange(bev_feat, '(h w) b c -> b c h w', h=self.bev_h, w=self.bev_w)  
            bev_feat = bev_feat + self.bev_adapter(bev_feat)  # 残差连接  
            bev_feat = rearrange(bev_feat, 'b c h w -> (h w) b c')  
          
        # 添加位置嵌入  
        pos_embed = self.pos_embed.weight  
        plan_query = plan_query + pos_embed[None]  
          
        # 使用注意力模块处理plan_query和bev_feat  
        plan_query = self.attn_module(plan_query, bev_feat)  
          
        # 回归分支,生成SDC轨迹  
        sdc_traj_all = self.reg_branch(plan_query).view((-1, self.planning_steps, 2))  
        # 累计求和,生成轨迹点  
        sdc_traj_all[...,:2] = torch.cumsum(sdc_traj_all[...,:2], dim=1)  
        # 对第一条轨迹应用双变量高斯激活  
        sdc_traj_all[0] = bivariate_gaussian_activation(sdc_traj_all[0])  
        # 如果使用碰撞优化且非训练模式,进行后处理  
        if self.use_col_optim and not self.training:  
            assert occ_mask is not None  
            sdc_traj_all = self.collision_optimization(sdc_traj_all, occ_mask)  
          
        # 返回SDC轨迹和所有SDC轨迹  
        return dict(  
            sdc_traj=sdc_traj_all,  
            sdc_traj_all=sdc_traj_all,  
        )

写在最后的话:

UniAD在自动驾驶的每个子任务中使用交叉注意力,显式指定了信息流向,且只能利用nuScenes类的封闭数据集。与之相对的,基于多模态基础模型的自动驾驶架构只使用自注意力层,将视觉信息映射为tokens,和多模态prompt的tokens一起输入自注意力层,多模态prompt隐式指定了信息流向,基础模型也已经用开放数据集进行了预训练。这是一种与UniAD互补的思路。

投稿作者为『自动驾驶之心知识星球』特邀嘉宾,欢迎加入交流!重磅,自动驾驶之心科研论文辅导来啦,申博、CCF系列、SCI、EI、毕业论文、比赛辅导等多个方向,欢迎联系我们!

e0a0dc600aa682b92929b0e0dfd80ca3.jpeg

① 全网独家视频课程

BEV感知、BEV模型部署、BEV目标跟踪、毫米波雷达视觉融合多传感器标定多传感器融合多模态3D目标检测车道线检测轨迹预测在线高精地图世界模型点云3D目标检测目标跟踪Occupancy、cuda与TensorRT模型部署大模型与自动驾驶Nerf语义分割自动驾驶仿真、传感器部署、决策规划、轨迹预测等多个方向学习视频(扫码即可学习

a8aea76a64e3a3b620c391506abe841f.png 网页端官网:www.zdjszx.com

② 国内首个自动驾驶学习社区

国内最大最专业,近3000人的交流社区,已得到大多数自动驾驶公司的认可!涉及30+自动驾驶技术栈学习路线,从0到一带你入门自动驾驶感知2D/3D检测、语义分割、车道线、BEV感知、Occupancy、多传感器融合、多传感器标定、目标跟踪)、自动驾驶定位建图SLAM、高精地图、局部在线地图)、自动驾驶规划控制/轨迹预测等领域技术方案大模型、端到端等,更有行业动态和岗位发布!欢迎扫描下方二维码,加入自动驾驶之心知识星球,这是一个真正有干货的地方,与领域大佬交流入门、学习、工作、跳槽上的各类难题,日常分享论文+代码+视频

902dbc14964008afa30ccf68c9dced01.png

③【自动驾驶之心】技术交流群

自动驾驶之心是首个自动驾驶开发者社区,聚焦感知、定位、融合、规控、标定、端到端、仿真、产品经理、自动驾驶开发、自动标注与数据闭环多个方向,目前近60+技术交流群,欢迎加入!扫码添加汽车人助理微信邀请入群,备注:学校/公司+方向+昵称(快速入群方式)

0712f46bc399259d8b53ab6caf1c15d7.jpeg

④【自动驾驶之心】全平台矩阵

303d2ab880c11824e0c85fdc2e56cf0c.png

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值