【pytorch】模型训练前数据加载
自学笔记,最新版本24.3.5
1 数据预处理、数据增广——torchvision.transforms
1.1 容器——transforms.Compose
同时对多种数据变换进行组合。
1.2 标准化—— transforms.Normalize(mean, std)
- 标准化原始数据的均值(Mean)和标准差(Standard Deviation)来进行数据的标准化,在经过标准化变换之后,数据全部符合均值为0、标准差为1的标准正态分布。
- mean和std是实现从原始数据计算出来的。
1.3 图像大小缩放——transforms.Resize(size)
1.4 随机概率p进行水平翻转—— transforms.RandomHorizontalFlip(p=)
1.5 随机概率p进行垂直翻转—— transforms.RandomVerticalFlip(p=)
1.6随机旋转一定角度——transforms.RandomRotation(degree=)
degree:加入degree是10,就是表示在(-10,10)之间随机旋转,如果是(30,60),就是30度到60度随机旋转
1.7 类型转换——transforms.ToTensor
转为pytorch可计算的
1.8 实例
import torchvision.transforms as T
transform = T.Compose([
T.Resize((args.input_size, args.input_size)), # 输入的尺寸
# 训练集需要数据增强
T.RandomHorizontalFlip(p=1), # p概率随机水平翻转
T.RandomRotation(degrees=10), # 随机旋转,degress是角度范围
T.RandomApply([T.GaussianBlur(kernel_size=3)], p=0.3), # 使用3X3的高斯滤波,以0.3概率进行变换
T.ColorJitter(brightness=0.4, contrast=0.4), # 添加随机的亮度和对比度增强
T.ToTensor(),# 将[0,255]归一化到[0,1],并且数据的shape从[H,W,C]变为[C,H,W]
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]) # 标准化,mwan是均值,std是标准差
])
2 私人数据集读取数据——MyDataSet类
- 需要写三个函数,初始化、数据长度、实例对象通过下表索引函数:
def __init__(self): # 初始化
def __len__(self): # 长度
def __getitem__(self, idx): # 索引
- init:初始化数据,包括获取图像、标签和预处理(transformers)
- len:返回一个数据集的长度
- getitem:以idx作为位置索引下标,读取到该位置的图像、处理好的标签等信息并返回。
- 调用方式:MyDataset类作为DataLoader的dataset参数值,在loder初始化时会一同初始化,一般通过遍历函数enumerate(loader)隐式触发该类的函数getitem,返回索引下的图像和标签。
- 流程:
- 创建MyDataset类的实例对象TrainDataset
- DataLoader函数参数dataset赋值TrainDataset
通过for in enumerate(loader):的方式遍历调用数据
3 记录模型训练过程中指标(metrics)的工具类——MetricLogger类
在util中的工具类
3.1 初始化实例对象——misc.MetricLogger()
metric_logger = misc.MetricLogger(delimiter=" ")
3.2 计算一系列数值的平滑值——misc.SmoothedValue(window_size=,fmt)
3.3 添加计量器——add_meter(name=,fmt=)
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
3.4 控制记录指标的频率——log_every()
header = 'Epoch: [{}]'.format(epoch)
print_freq = 20
metric_logger.log_every(data_loader, print_freq, header)
- data_loader:数据集
- print_freq:记录频率
- header:记录形式
3.5 更新指标值——update()
3.6 同步多进程的状态——metric_logger.synchronize_between_processes()
- 在多个进程之间同步 MetricLogger 对象的状态。
- 这个方法通常在多进程训练中使用,用于确保不同进程中的 MetricLogger 对象记录的指标值是一致的。
4 DataFramel类行和列操作——.loc[:,:]和.iloc[:,:]
- iloc:使用0-len(list)的下标作为索引,类同数组的下标索引,只能是数字;
- loc:实际设置的索引,可以是字符,也可以是数字;
5 数据集分层抽样,保证类别分布比例——StratifiedShuffleSplit
- Scikit-learn 中的一个用于交叉验证的类,允许在划分数据集时保持类别的分布比例。具体来说,它可以用于将数据集随机分成训练集和测试集,并且在分割时保持每个类别样本的比例。
- 例子:
from sklearn.model_selection import StratifiedShuffleSplit
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in splitter.split(X, y):
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
- 参数:
6 分布式训练的数据采样器——torch.utils.data.DistributedSampler
- 分布式数据采样: 在分布式训练环境中,每个进程或设备都需要处理不同的数据,DistributedSampler 可以确保每个进程或设备获取到的数据都是不重复的,从而避免了重复训练数据导致模型过拟合的问题。
- 数据加载的平衡性: DistributedSampler 还可以根据数据集的大小和设备的数量等信息,合理地对数据进行分配,以确保每个进程或设备获取到的数据量是平衡的,从而提高训练的效率和性能。
- 示例:
from torch.utils.data import DistributedSampler
sampler = DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
7 加载数据——torch.utils.data.DataLoader
- 主要用于构建数据管道,将数据集提供给模型进行训练或推理。它可以处理数据的加载、批处理、随机化、多线程数据加载等任务。
- 示例:
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
for inputs, targets in dataloader:
# 在这里进行模型的训练或推理
8 数据加载流程总结
数据加载之前要处理的内容,主要是构建数据集dataset_train和dataset_val以及分布式采样Sampler:
1)数据信息csv文件:将图像的地址、分类作为csv保存,用于读取;
2)构建MyDataset类的数据集:编写util/dataset.py文件,包括预处理相关的transformer类,MyDataset类(初始化从csv中读取图像和标签数据,getitem函数返回索引自动调用的数据)
3)分层抽样Split:从csv中获取等比的训练测试集索引
4)将分层抽样后的数据集导入dataset.py中的函数构建MyDataset类的实例对象dataset_train和dataset_val。
5)分布式采样torch.utils.data.DistributedSampler:设置数据加载使用的sampler
6)构建 torch.utils.data.DataLoader,已经获得数据集和分布式采样,其他的参数来源于args。