CUDA Error解决
- CUDA out of memory ,在模型的训练过程中,可能一开始会报这个错,也可能运行几个迭代之后报这个错,具体报错解决如下:
--amp
AMP技术解释
- AMP(自动混合精度)是一个用于深度学习训练中的技术,它通过智能地使用半精度(浮点16)和全精度(浮点32)浮点运算来加速模型的训练并减少内存使用。
- 提高训练速度:
- 运算加速:半精度浮点数(FP16)比全精度浮点数(FP32)需要的内存更少。这允许GPU在相同时间内处理更多的数据
- 网络吞吐量提升:由于数据包大小减小,使用FP16可以在GPU上加载更多的数据和模型参数
- 减少内存使用:
- 节约内存: FP16相比FP32可以节约一半的内存,这对于大模型数据和大批量数据尤其重要,因为它们对显存的需求非常高
- 允许更大的批量大小:由于每个数据点使用的内存减少,可以在不超过GPU内存限制的情况下增加批量大小,这有助于更稳定和有效的训练
- 为什么AMP可以避免超出显存:
- 显存管理优化:AMP提供了动态的显存管理技术,这包括在必要时自动调整数据的精度。例如在某些操作中使用FP16,在需要高精度计算的操作中自动切换回FP32,这样既保证了计算精度,又优化了显存使用
- 避免内存溢出:通过减少单个数据元素的内存需求,AMP可以有效降低整体内存消耗,这样在相同的硬件条件下,可以运行更大的模型或者使用更大的数据批量。从而避免因内存不足而导致的程序崩溃或者性能下降
- 如何使用AMP:
- 使用torch.cuda.amp模块来实现自动混合精度,该模块提供了“autocast”、"GradScaler"这两个工具,其中“autocast”用于自动调整运算的数据类型,“GradScaler”用于调整梯度的规模,防止在使用FP16时因数值范围限制而造成的梯度下溢。
输入尺寸控制—图像预处理
- 输入图像在输入模型之前需要被缩放到一个统一的尺寸,同时保持图像原始的长宽比不变,以适应模型的输入需求,通常由–max_size_train、–min_size_train这两个参数控制
- 在ESTS中,我们通过–min_size_train来控制模型的最短边的尺寸,同时由于我们指定了多个min值,所以在每次迭代时,模型会随机从中选择一个使用。作为一种数据增强策略,可以帮助模型更好的学习从不同尺寸的图片中提取特征,增加模型的泛化性。
- 我们使用–max_size_train来控制模型的最长边的尺寸,模型会在对图片进行缩放后检查是否有最长边超过这个尺寸,如果有,则会将其缩小至max_size以满足最大尺寸限制。
- 举例说明:
1. 参数设置:--min_size_train 设置为 [640, 672, 704, 736, 768, 800, 832, 864, 896]。
2. 操作:在每次训练迭代开始时,从这个列表中随机选择一个尺寸值作为目标尺寸。假设在一个特定的训练迭代中选择了 800。
3. 图像处理:每个训练图像将被重新缩放,使得其最短边至少为 800 像素,同时保持图像原始的长宽比不变。
例如,如果原图是 1600x1200 像素,则缩放后的尺寸将为 1067x800 像素。
1. 参数设置:--max_size_train 设置为 1600。
2. 操作:在调整了最小尺寸之后,检查图像的最长边是否超过了 1600 像素。
3. 图像处理:如果最长边超过了 1600 像素,整个图像会进一步按比例缩小,直到最长边等于 1600 像素。继续上面的例子,
如果缩放到 1067x800 后,假设再应用其他变换使得图像增大到 1800x1350 像素,则会将其缩小至 1600x1200 像素以满足最大尺寸限制
1. 统一输入尺寸:神经网络通常要求输入具有统一的尺寸,通过这种方式,我们可以确保所有输入图像都满足模型的要求。
2. 增强数据:通过在一系列预定义的尺寸中随机选择,可以增加模型训练的多样性,帮助模型更好地泛化到不同尺寸的图像。
3. 性能优化:限制最大尺寸有助于避免因图像太大导致的计算和存储资源过度消耗。
Resume
- resume是一种允许从中断或先前的检查点重新开始训练的机制,它可以在训练过程中遇到意外情况后继续训练,而不是从头开始。
- 通过“–resume”参数来接收checkpoint文件路径,文件可能为.pt 、 .pth文件等
- 在训练过程中,模型的状态(模型的参数权重、优化器的状态、当前的训练轮次等)会周期性的保存为checkpoint检查点文件。这通常在每个epoch结束后自动发送,或者在特定的迭代次数后发生。
- checkpoint文件通常包含:
- 模型参数(权重和偏置)
- 优化器状态(例如,Adam优化器的动量和自适应学习率状态)
- 训练轮次或步数
- 其他可能影响训练继续进行的状态(例如学习率计划)
- 我们可以通过
torch.load()
来查看,pt文件中包含什么内容,由于加载后的对象通常是一个字典,所以我们可以根据字典中的键来确定文件包含哪些类型的数据。例如以下代码:
import torch
checkpoint = torch.load('path_to_checkpoint.pt')
print("Keys in the checkpoint:", checkpoint.keys())
if 'state_dict' in checkpoint:
print("The checkpoint contains model parameters.")
if 'optimizer' in checkpoint:
print("The checkpoint contains optimizer states.")
if 'epoch' in checkpoint:
print("The checkpoint contains epoch information.")