目录
1. train.py
文件分析
train.py
是整个项目的训练脚本,负责模型的训练和验证流程。它使用 PyTorch Lightning 来简化训练过程,自动处理分布式训练、检查点保存等任务。
主要功能:
-
导入模块:
- 导入必要的模块,如 PyTorch、PyTorch Lightning、模型和数据集等。
- 从
TCP.model
导入自定义模型TCP
,从TCP.data
导入数据集类CARLA_Data
。
-
TCP_planner 类:
- 继承自
pl.LightningModule
,这个类封装了模型、损失函数、优化器等训练的核心部分。 _load_weight
方法:从检查点中加载预训练模型的权重,用于初始化模型参数。training_step
方法:定义每一步训练时的前向传播和损失计算。通过使用 PyTorch 的Beta
分布计算动作预测误差(KL 散度)、速度误差、路径点误差等。validation_step
方法:验证步骤类似于训练步骤,只是不进行反向传播。计算验证损失。configure_optimizers
方法:配置优化器为 Adam,并定义学习率调度策略。
- 继承自
-
命令行参数解析:
- 使用
argparse
模块从命令行传入训练相关的参数(如学习率、批次大小、GPU 数量等)。
- 使用
-
训练过程:
- 创建
CARLA_Data
数据集,并生成训练集和验证集的DataLoader
。 - 使用 PyTorch Lightning 的
Trainer
来管理训练流程,包括分布式训练(通过DDPPlugin
)和保存最佳模型(通过ModelCheckpoint
)。
- 创建
关键点:
- 该文件是整个训练过程的主入口,控制着模型训练、验证以及检查点保存。
- 通过继承
LightningModule
,简化了训练步骤和优化器配置的复杂性。
2. resnet.py
文件分析
resnet.py
实现了 ResNet 模型的定义。ResNet 是一种广泛用于图像分类、特征提取等任务的卷积神经网络。
主要功能:
-
卷积操作定义:
conv3x3
:定义了 3x3 的卷积操作,主要用于提取局部特征。conv1x1
:定义了 1x1 的卷积操作,主要用于改变通道数。
-
BasicBlock 和 Bottleneck:
BasicBlock
:这是 ResNet-18 和 ResNet-34 所用的基本模块。它包含了两个 3x3 的卷积层,通过残差连接保留输入信息。Bottleneck
:这是 ResNet-50、ResNet-101 和 ResNet-152 使用的瓶颈模块。它通过 1x1、3x3 和 1x1 卷积层扩展和压缩通道,减少计算复杂度。
-
ResNet 类:
- 该类定义了如何通过堆叠多个
BasicBlock
或Bottleneck
来构建 ResNet 网络。 _make_layer
:负责构建每一层的残差块。forward
方法:定义了前向传播的过程,包括图像通过卷积层、池化层、全连接层的顺序。
- 该类定义了如何通过堆叠多个
-
预训练模型加载:
- 使用
torch.hub.load_state_dict_from_url
函数下载并加载 ImageNet 上预训练的 ResNet 权重。
- 使用
关键点:
- 该文件实现了 ResNet 网络的多个变体(如 ResNet-18、ResNet-50 等),并允许加载预训练模型。
- 这些模型通常用于提取图像特征,后续的控制模块会使用这些特征进行预测。
3. augment.py
文件分析
augment.py
文件主要用于图像数据增强,使用了 imgaug
库来实现多种数据增强操作,增加数据的多样性以提高模型的泛化能力。
主要功能:
-
hard
函数:- 这个函数根据
image_iteration
的值调整图像增强的强度,定义了多种增强方式,如模糊、噪声、亮度调整、对比度调整等。 - 增强操作:
GaussianBlur
:添加高斯模糊,模糊强度随着迭代次数增加。AdditiveGaussianNoise
:向图像中添加高斯噪声。Dropout
和CoarseDropout
:随机丢弃图像中的部分像素,模拟遮挡和损坏。Add
和Multiply
:随机调整图像的亮度。Grayscale
:将图像转换为灰度。
- 这个函数根据
-
hard_1
函数:- 与
hard
函数类似,也定义了一系列增强操作,参数不同,用于构建稍微不同的增强效果。
- 与
关键点:
- 这些增强操作有助于模拟自动驾驶场景中的多种光照和传感器噪声情况,使模型对不同环境下的输入数据更加鲁棒。
4. config.py
文件分析
config.py
文件定义了全局配置类 GlobalConfig
,用于管理模型训练中的超参数和路径设置。
主要功能:
-
数据路径:
root_dir_all
定义了数据集的根目录,并通过循环将各个训练集和验证集的文件夹路径添加到train_data
和val_data
中。
-
模型与控制器超参数:
- 定义了控制器的比例、积分、微分系数(
turn_KP
,speed_KP
等),这些用于自动驾驶控制信号的生成。 seq_len
和pred_len
控制输入序列和输出路径点的长度。speed_weight
,value_weight
,features_weight
控制各个损失项的权重。
- 定义了控制器的比例、积分、微分系数(
-
图像处理参数:
- 包含图像的分辨率、裁剪大小、缩放比例等参数,用于图像预处理。
关键点:
- 该文件便于实验中的配置管理,可以通过
kwargs
动态调整不同实验的参数。 - 通过集中管理配置,能够方便地调整模型训练中的各类参数。
5. data.py
文件分析
data.py
文件定义了数据集加载类 CARLA_Data
,该类继承自 PyTorch 的 Dataset
类,主要用于加载和预处理 CARLA 仿真平台的数据。
主要功能:
-
数据加载:
- 通过
np.load
加载.npy
文件中的训练数据,包括图像、车辆状态、控制信号、未来轨迹等。 - 支持图像增强,通过
augment.py
中的增强函数在加载图像时对图像进行增强处理。
- 通过
-
图像预处理:
- 使用
torchvision.transforms
进行图像的归一化处理(ToTensor 和 Normalize)。
- 使用
-
__getitem__
方法:- 该方法返回指定索引的数据项,包括输入图像、未来轨迹点、控制信号等。
- 对于每个输入图像,系统会计算它相对于当前状态的局部目标点和未来轨迹,并返回相关数据。
-
辅助函数:
scale_and_crop_image
:用于图像的缩放和裁剪。transform_2d_points
和rot_to_mat
:用于坐标变换,将全局坐标系中的点转换为车辆局部参考系中的点。get_action_beta
:用于从 Beta 分布中计算动作。
关键点:
- 该文件通过
CARLA_Data
类封装了数据的加载和预处理,便于训练过程中高效读取和增强数据。 - 数据集包括输入图像、路径点、未来的轨迹和控制信号,支持路径规划任务的训练。
总结:
train.py
:训练和验证过程的核心脚本。resnet.py
:实现 ResNet 特征提取网络。augment.py
:图像数据增强。config.py
:全局配置文件,管理超参数和数据路径。data.py
:数据加载和预处理,主要用于从 CARLA 仿真平台加载自动驾驶数据。
每个文件在整个项目中都有其独特的作用,并且通过良好的模块化设计,这些文件紧密合作,共同实现了自动驾驶路径规划的训练与验证过程。如果你有更多问题或想要更深入