(7-6)行为预测算法:基于Trajectron++模型的行为预测系统

7.6  基于Trajectron++模型的行为预测

Trajectron++是一个用于多目标轨迹预测和规划的深度学习模型,旨在应对自动驾驶和机器人等领域中的挑战,其中多个移动目标需要被准确地预测其未来运动轨迹,以便做出智能决策。

7.6.1  Trajectron++模型的特点

Trajectron++模型的主要特点和功能如下所示。

  1. 多目标轨迹预测:Trajectron++ 的核心任务是预测多个移动目标的未来运动轨迹,这对于自动驾驶车辆、机器人等在复杂交通场景中的行为规划至关重要。
  2. 深度学习架构:Trajectron++模型采用深度学习技术,包括循环神经网络(RNN)和卷积神经网络(CNN),以便有效地处理时间序列和空间信息,从而更好地捕捉目标的运动模式。
  3. 多智能体建模:Trajectron++ 考虑了多个移动目标之间的相互作用和关系。这有助于更准确地预测每个目标的轨迹,因为它们的运动可能受到彼此的影响。
  4. 生成式模型:Trajectron++ 是一个生成式模型,它可以生成可能的未来轨迹的概率分布。这使得它能够更灵活地处理不确定性,对于智能决策非常重要。
  5. 实时性能:Trajectron++ 被设计成具有实时性能,以便在实际应用中能够及时地做出决策。

在实际应用中,Trajectron++模型被主要用在自动驾驶领域,能够通过对轨迹集合建模来捕捉不确定性,用于车辆和行人的轨迹预测,为实体感知和决策提供了强大的支持。

7.6.2  基于Trajectron++模型的行为预测系统

在本项目中,使用 PyTorch Lightning 和 Lyft 提供的 l5kit 工具包实现了一个灵活的数据加载器,支持多智能体训练。展示了配置和使用不同的数据集、rasterizer,并提供了对训练数据批次结构的详细解析,为进一步的模型训练和实验奠定了基础。

实例2-8Trajectron++行为预测系统codes/2/lyft-multi-agent.ipynb

本项目使用的是Lyft 公司提供的自动驾驶车辆运动预测数据集(Lyft Motion Prediction Autonomous Vehicles),这是一个开源数据集。该数据集的目标是帮助研究人员和开发者训练和评估自动驾驶车辆在城市环境中的运动预测能力。数据集Lyft Motion Prediction for Autonomous Vehicles的主要特点和内容如下所示。

  1. 场景和环境:数据集提供了在城市环境中采集的大量传感器数据,涵盖了各种驾驶场景,包括道路、交叉口、人行道等。这使得研究人员能够测试和优化自动驾驶系统在不同复杂环境中的性能。
  2. 传感器数据:数据集包含了来自各种传感器的信息,如激光雷达、摄像头、雷达等。这些传感器数据为车辆周围的环境提供了高分辨率的感知信息。
  3. 运动轨迹和预测:Lyft 数据集中包含了车辆的历史运动轨迹数据,并提供了对未来运动轨迹的预测。这使得研究人员可以训练和评估模型在预测其他车辆或行人行为时的准确性。
  4. 地图信息:数据集可能包括高精度地图信息,以帮助自动驾驶车辆更好地理解和导航城市环境。
  5. 用于研究的挑战性问题:Lyft 数据集通常包含一些挑战性的问题,以促使研究人员开发创新性的算法和模型。这有助于推动自动驾驶技术的发展。
  6. 使用工具包:Lyft 提供了与数据集一起使用的工具包,如 l5kit,以便更轻松地处理和分析数据。

通过使用 Lyft 的自动驾驶车辆运动预测数据集,研究人员可以进行各种实验和测试,以提高自动驾驶系统在复杂城市交通中的性能。请查阅 Lyft 公司的官方文档或数据集页面,以获取更详细和最新的信息。实例文件lyft-multi-agent.ipynb的具体实现流程如下所示。

(1)导入必要的库和模块,设置全局变量,以及检测当前是否在 Kaggle 环境中。其中,l5kit 是 Lyft 公司提供的用于处理自动驾驶车辆运动预测数据集的工具包,而代码中的变量和模块则为后续的数据处理和模型训练做准备。

import bisect
import os
from copy import deepcopy
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pytorch_lightning as pl
from l5kit.data import ChunkedDataset, LocalDataManager
from l5kit.dataset import AgentDataset
from l5kit.rasterization import StubRasterizer, build_rasterizer
from torch.utils.data import DataLoader, Dataset, Subset

is_kaggle = os.path.isdir("")

(2)定义一个名为 CONFIG_DATA 的字典,其中包含了有关模型、光栅化参数以及数据加载器的配置信息。具体而言,它包括了模型的架构(resnet34)、历史和未来运动轨迹的帧数和时间步长、光栅化的相关参数、训练和验证数据加载器的配置等。这个配置字典被用于指定模型和数据的相关参数,以便在后续的训练和评估中使用。

# 配置数据字典,包括模型、光栅化参数以及数据加载器的相关配置信息
CONFIG_DATA = {
    "format_version": 4,
    "model_params": {
        "model_architecture": "resnet34",  # 模型架构
        "history_num_frames": 10,  # 历史运动轨迹的帧数
        "history_step_size": 1,  # 历史轨迹的时间步长
        "history_delta_time": 0.1,  # 历史轨迹的时间步长
        "future_num_frames": 50,  # 未来运动轨迹的帧数
        "future_step_size": 1,  # 未来轨迹的时间步长
        "future_delta_time": 0.1,  # 未来轨迹的时间步长
    },
    "raster_params": {
        "raster_size": [256, 256],  # 光栅化图像大小
        "pixel_size": [0.5, 0.5],  # 像素大小
        "ego_center": [0.25, 0.5],  # 智能驾驶汽车中心相对位置
        "map_type": "py_semantic",  # 地图类型
        "satellite_map_key": "aerial_map/aerial_map.png",  # 卫星地图键
        "semantic_map_key": "semantic_map/semantic_map.pb",  # 语义地图键
        "dataset_meta_key": "meta.json",  # 数据集元数据键
        "filter_agents_threshold": 0.5,  # 过滤智能体的阈值
        "disable_traffic_light_faces": False,  # 是否禁用交通灯的人脸
    },
    "train_dataloader": {
        "key": "scenes/sample.zarr",  # 训练数据集键
        "batch_size": 24,  # 批量大小
        "shuffle": True,  # 是否打乱数据
        "num_workers": 0,  # 数据加载器的工作进程数
    },
    "val_dataloader": {
        "key": "scenes/validate.zarr",  # 验证数据集键
        "batch_size": 24,  # 批量大小
        "shuffle": False,  # 是否打乱数据
        "num_workers": 4,  # 数据加载器的工作进程数
    },
    "test_dataloader": {
        "key": "scenes/test.zarr",  # 测试数据集键
        "batch_size": 24,  # 批量大小
        "shuffle": False,  # 是否打乱数据
        "num_workers": 4,  # 数据加载器的工作进程数
    },
    "train_params": {
        "max_num_steps": 400,  # 最大训练步数
        "eval_every_n_steps": 50,  # 每隔多少步进行一次评估
    },
}

(3)创建一个名为 MultiAgentDataset 的 PyTorch 数据集类,用于组合两个不同的 AgentDataset 数据集以创建一个新的多智能体数据集。该数据集用于训练神经网络等模型,以预测多个智能体(例如车辆)的运动轨迹。

from typing import List, Dict, Any, Tuple

class MultiAgentDataset(Dataset):
    def __init__(
        self,
        rast_only_agent_dataset: AgentDataset,
        history_agent_dataset: AgentDataset,
        num_neighbors: int = 10,
    ):
        super().__init__()
        self.rast_only_agent_dataset = rast_only_agent_dataset  # 光栅信息数据集
        self.history_agent_dataset = history_agent_dataset  # 历史信息数据集
        self.num_neighbors = num_neighbors  # 其他智能体数量

    def __len__(self) -> int:
        return len(self.rast_only_agent_dataset)  # 返回数据集长度

    def get_others_dict(
        self, index: int, ego_dict: Dict[str, Any]
    ) -> Tuple[List[Dict[str, Any]], int]:
        agent_index = self.rast_only_agent_dataset.agents_indices[index]  # 获取智能体索引
        frame_index = bisect.bisect_right(
            self.rast_only_agent_dataset.cumulative_sizes_agents, agent_index
        )  # 查找所属帧索引
        frame_indices = self.rast_only_agent_dataset.get_frame_indices(frame_index)
        assert len(frame_indices) >= 1, frame_indices
        frame_indices = frame_indices[frame_indices != index]  # 剔除当前智能体索引

        others_dict = []
        # 当前帧中 AV 的质心在世界参考系中的坐标,单位为米
        for idx, agent in zip(
            frame_indices,
            Subset(self.history_agent_dataset, frame_indices),
        ):
            agent["dataset_idx"] = idx
            agent["dist_to_ego"] = np.linalg.norm(
                agent["centroid"] - ego_dict["centroid"], ord=2
            )  # 计算到当前智能体的距离
            # 在未来版本中,可以通过智能体和智能驾驶汽车的转换矩阵将历史位置转换为归一化版本
            # 并获得标准化的版本
            del agent["image"]
            others_dict.append(agent)

        others_dict = sorted(others_dict, key=itemgetter("dist_to_ego"))
        others_dict = others_dict[: self.num_neighbors]
        others_len = len(others_dict)

        # 必须填充,因为 torch 不支持不规则张量
        # https://github.com/pytorch/pytorch/issues/25032
        length_to_pad = self.num_neighbors - others_len
        pad_item = deepcopy(ego_dict)
        pad_item["dataset_idx"] = index
        pad_item["dist_to_ego"] = np.nan  # 设置为 nan 以防止误用
        del pad_item["image"]
        return (others_dict + [pad_item] * length_to_pad, others_len)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        rast_dict = self.rast_only_agent_dataset[index]
        ego_dict = self.history_agent_dataset[index]
        others_dict, others_len = self.get_others_dict(index, ego_dict)
        ego_dict["image"] = rast_dict["image"]
        return {
            "ego_dict": ego_dict,
            "others_dict": others_dict,
            "others_len": others_len,
        }

(4)定义一个 PyTorch Lightning 的数据模块 LyftAgentDataModule,用于管理 Lyft 自动驾驶车辆运动预测数据集的加载和处理。通过配置信息,创建了训练、验证和测试数据加载器,实现了数据的统一管理和准备,方便在 PyTorch Lightning 中进行模型训练和评估工作。

class LyftAgentDataModule(pl.LightningDataModule):
    def __init__(self, cfg: Dict = CONFIG_DATA, data_root: str = data_root):
        super().__init__()
        self.cfg = cfg
        self.dm = LocalDataManager(data_root)
        self.rast = build_rasterizer(self.cfg, self.dm)

    def chunked_dataset(self, key: str):
        dl_cfg = self.cfg[key]
        dataset_path = self.dm.require(dl_cfg["key"])
        zarr_dataset = ChunkedDataset(dataset_path)
        zarr_dataset.open()
        return zarr_dataset

    def get_dataloader_by_key(
        self, key: str, mask: Optional[np.ndarray] = None
    ) -> DataLoader:
        dl_cfg = self.cfg[key]
        zarr_dataset = self.chunked_dataset(key)
        agent_dataset = AgentDataset(
            self.cfg, zarr_dataset, self.rast, agents_mask=mask
        )
        return DataLoader(
            agent_dataset,
            shuffle=dl_cfg["shuffle"],
            batch_size=dl_cfg["batch_size"],
            num_workers=dl_cfg["num_workers"],
            pin_memory=True,
        )

    def train_dataloader(self):
        key = "train_dataloader"
        return self.get_dataloader_by_key(key)

    def val_dataloader(self):
        key = "val_dataloader"
        return self.get_dataloader_by_key(key)

    def test_dataloader(self):
        key = "test_dataloader"
        test_mask = np.load(f"{data_root}/scenes/mask.npz")["arr_0"]
        return self.get_dataloader_by_key(key, mask=test_mask)

上述代码的实现流程如下:

  1. 首先,创建类LyftAgentDataModule,继承自 PyTorch Lightning 的 类LightningDataModule,用于管理 Lyft 自动驾驶车辆运动预测数据集的加载和数据处理。
  2. 然后,在初始化函数 __init__ 中,配置信息 cfg 和数据根目录 data_root 被传递并存储。LocalDataManager 被用于管理本地数据路径,而 build_rasterizer 函数被用于创建光栅化器。
  3. 接着,chunked_dataset 函数用于获取指定键(key)对应的数据集。该函数根据配置信息中的键获取数据集路径,使用 ChunkedDataset 打开并返回。
  4. 接下来,get_dataloader_by_key 函数通过指定的键获取对应的数据加载器。首先,使用 chunked_dataset 函数获取数据集,然后使用 AgentDataset 类构建代理数据集,传入配置信息、数据集、光栅化器和智能体的掩码(如果有的话)。最后,使用 PyTorch 的 DataLoader 创建数据加载器,配置包括是否打乱数据、批量大小、工作进程数等。
  5. 最后,train_dataloader、val_dataloader 和 test_dataloader 函数分别用于获取训练、验证和测试数据加载器。这些函数调用了 get_dataloader_by_key,传入相应的键,同时在测试数据加载器中还传入了智能体的掩码。这样,整个 LyftAgentDataModule 类提供了训练、验证和测试数据加载器的统一接口,方便在 PyTorch Lightning 中进行训练和评估。

(5)定义一个名为 MultiAgentDataModule 的 PyTorch Lightning 数据模块,继承自 LyftAgentDataModule。通过对智能体数据集进行定制化配置,创建了一个用于多智能体训练的数据加载器。

from pprint import pprint
for item in datamodule.train_dataloader():
    pprint(item.keys())
    print('ego_dict keys')
    pprint(item['ego_dict'].keys())
    pprint(len(item['others_dict']))
    pprint(item['others_dict'][0].keys())
    pprint(item['others_len'])
    break

对上述代码的具体说明如下所示:

  1. 首先,定义类MultiAgentDataModule,继承自 LyftAgentDataModule。在初始化函数中调用父类的初始化,同时创建了一个用于调试的 StubRasterizer。
  2. 然后,通过 get_dataloader_by_key 函数获取训练数据加载器。该函数使用 AgentDataset 类构建了两个数据集:一个只包含光栅信息的智能体数据集和一个使用 StubRasterizer 的包含历史信息的智能体数据集。
  3. 接着,通过创建 MultiAgentDataset 实例,将上述两个数据集传递给 PyTorch 的 DataLoader,配置了是否打乱数据、批量大小、工作进程数等参数,以便用于模型的训练。
  4. 最后,通过创建 MultiAgentDataModule 实例,完成了整个数据模块的配置和准备,方便在 PyTorch Lightning 中进行多智能体训练。

(6)通过训练数据加载器获取了一个批次的数据,并使用函数 pprint打印输出了该批次数据的结构和内容信息。首先,展示了整个批次数据的键;然后,详细列出了 'ego_dict' 中的键和信息;接着,显示了 'others_dict' 列表的长度以及第一个元素的键和信息;最后展示了 'others_len' 的值,提供了对数据批次中智能驾驶汽车和其他智能体信息的详尽了解。

dict_keys(['ego_dict', 'others_dict', 'others_len'])
ego_dict keys
dict_keys(['image', 'target_positions', 'target_yaws', 'target_availabilities', 'history_positions', 'history_yaws', 'history_availabilities', 'world_to_image', 'raster_from_world', 'raster_from_agent', 'agent_from_world', 'world_from_agent', 'track_id', 'timestamp', 'centroid', 'yaw', 'extent'])
10
dict_keys(['target_positions', 'target_yaws', 'target_availabilities', 'history_positions', 'history_yaws', 'history_availabilities', 'world_to_image', 'raster_from_world', 'raster_from_agent', 'agent_from_world', 'world_from_agent', 'track_id', 'timestamp', 'centroid', 'yaw', 'extent', 'dataset_idx', 'dist_to_ego'])
tensor([10,  7,  7,  2,  5,  9,  4,  4, 10,  5,  6,  4,  5, 10,  6, 10, 10,  5,
         9,  9,  1, 10,  5,  3])

  • 43
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

码农三叔

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值