【pytorch】模型训练前数据加载——数据增强transforms/MyDataSet类/MetricLogger类/加载数据


自学笔记,最新版本24.3.5


1 数据预处理、数据增广——torchvision.transforms

学习:PyTorch之torchvision.transforms详解[原理+代码实现]

1.1 容器——transforms.Compose

同时对多种数据变换进行组合。

1.2 标准化—— transforms.Normalize(mean, std)

  1. 标准化原始数据的均值(Mean)和标准差(Standard Deviation)来进行数据的标准化,在经过标准化变换之后,数据全部符合均值为0、标准差为1的标准正态分布。
  2. mean和std是实现从原始数据计算出来的。

1.3 图像大小缩放——transforms.Resize(size)

1.4 随机概率p进行水平翻转—— transforms.RandomHorizontalFlip(p=)

1.5 随机概率p进行垂直翻转—— transforms.RandomVerticalFlip(p=)

1.6随机旋转一定角度——transforms.RandomRotation(degree=)

degree:加入degree是10,就是表示在(-10,10)之间随机旋转,如果是(30,60),就是30度到60度随机旋转

1.7 类型转换——transforms.ToTensor

转为pytorch可计算的

1.8 实例

import torchvision.transforms as T
transform = T.Compose([
                    T.Resize((args.input_size, args.input_size)),  # 输入的尺寸
                    # 训练集需要数据增强
                    T.RandomHorizontalFlip(p=1),  # p概率随机水平翻转
                    T.RandomRotation(degrees=10),  # 随机旋转,degress是角度范围
                    T.RandomApply([T.GaussianBlur(kernel_size=3)], p=0.3),  # 使用3X3的高斯滤波,以0.3概率进行变换
                    T.ColorJitter(brightness=0.4, contrast=0.4),  # 添加随机的亮度和对比度增强
                    T.ToTensor(),# 将[0,255]归一化到[0,1],并且数据的shape从[H,W,C]变为[C,H,W]
                    T.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225]) # 标准化,mwan是均值,std是标准差
            ])

2 私人数据集读取数据——MyDataSet类

  1. 需要写三个函数,初始化、数据长度、实例对象通过下表索引函数:
def __init__(self): # 初始化
def __len__(self):  # 长度
def __getitem__(self, idx):  # 索引
  1. init:初始化数据,包括获取图像、标签和预处理(transformers)
  2. len:返回一个数据集的长度
  3. getitem:以idx作为位置索引下标,读取到该位置的图像、处理好的标签等信息并返回。
  4. 调用方式:MyDataset类作为DataLoader的dataset参数值,在loder初始化时会一同初始化,一般通过遍历函数enumerate(loader)隐式触发该类的函数getitem,返回索引下的图像和标签。
  5. 流程:
  1. 创建MyDataset类的实例对象TrainDataset
  2. DataLoader函数参数dataset赋值TrainDataset
    通过for in enumerate(loader):的方式遍历调用数据

3 记录模型训练过程中指标(metrics)的工具类——MetricLogger类

在util中的工具类

3.1 初始化实例对象——misc.MetricLogger()

metric_logger = misc.MetricLogger(delimiter="  ")

3.2 计算一系列数值的平滑值——misc.SmoothedValue(window_size=,fmt)

在这里插入图片描述

3.3 添加计量器——add_meter(name=,fmt=)

metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))

在这里插入图片描述

3.4 控制记录指标的频率——log_every()

header = 'Epoch: [{}]'.format(epoch)
print_freq = 20
metric_logger.log_every(data_loader, print_freq, header)
  1. data_loader:数据集
  2. print_freq:记录频率
  3. header:记录形式

3.5 更新指标值——update()

3.6 同步多进程的状态——metric_logger.synchronize_between_processes()

  1. 在多个进程之间同步 MetricLogger 对象的状态。
  2. 这个方法通常在多进程训练中使用,用于确保不同进程中的 MetricLogger 对象记录的指标值是一致的。

4 DataFramel类行和列操作——.loc[:,:]和.iloc[:,:]

Python学习.iloc和.loc区别、联系与用法
在这里插入图片描述

  1. iloc:使用0-len(list)的下标作为索引,类同数组的下标索引,只能是数字;
  2. loc:实际设置的索引,可以是字符,也可以是数字;

5 数据集分层抽样,保证类别分布比例——StratifiedShuffleSplit

  1. Scikit-learn 中的一个用于交叉验证的类,允许在划分数据集时保持类别的分布比例。具体来说,它可以用于将数据集随机分成训练集和测试集,并且在分割时保持每个类别样本的比例。
  2. 例子:
from sklearn.model_selection import StratifiedShuffleSplit
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in splitter.split(X, y):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
  1. 参数:
    在这里插入图片描述

6 分布式训练的数据采样器——torch.utils.data.DistributedSampler

  1. 分布式数据采样: 在分布式训练环境中,每个进程或设备都需要处理不同的数据,DistributedSampler 可以确保每个进程或设备获取到的数据都是不重复的,从而避免了重复训练数据导致模型过拟合的问题。
  2. 数据加载的平衡性: DistributedSampler 还可以根据数据集的大小和设备的数量等信息,合理地对数据进行分配,以确保每个进程或设备获取到的数据量是平衡的,从而提高训练的效率和性能。
  3. 示例:
from torch.utils.data import DistributedSampler
sampler = DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)

在这里插入图片描述

7 加载数据——torch.utils.data.DataLoader

  1. 主要用于构建数据管道,将数据集提供给模型进行训练或推理。它可以处理数据的加载、批处理、随机化、多线程数据加载等任务。
  2. 示例:
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
for inputs, targets in dataloader:
    # 在这里进行模型的训练或推理

在这里插入图片描述


8 数据加载流程总结

数据加载之前要处理的内容,主要是构建数据集dataset_train和dataset_val以及分布式采样Sampler:
1)数据信息csv文件:将图像的地址、分类作为csv保存,用于读取;
2)构建MyDataset类的数据集:编写util/dataset.py文件,包括预处理相关的transformer类,MyDataset类(初始化从csv中读取图像和标签数据,getitem函数返回索引自动调用的数据)
3)分层抽样Split:从csv中获取等比的训练测试集索引
4)将分层抽样后的数据集导入dataset.py中的函数构建MyDataset类的实例对象dataset_train和dataset_val。
5)分布式采样torch.utils.data.DistributedSampler:设置数据加载使用的sampler
6)构建 torch.utils.data.DataLoader,已经获得数据集和分布式采样,其他的参数来源于args。

  • 15
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
PyTorch是一个基于Python的科学计算包,其主要功能是进行张量计算和深度学习模型构建。在深度学习中,数据加载是一个重要的环节,PyTorch提供了一些工具和函数来简化数据加载的过程。 PyTorch数据加载主要涉及到两个:`torch.utils.data.Dataset`和`torch.utils.data.DataLoader`。其中,`Dataset`用于表示数据集,而`DataLoader`则用于对数据集进行加载和处理。 使用PyTorch进行数据加载的基本步骤如下: 1. 定义数据集:需要继承`torch.utils.data.Dataset`,并实现`__len__`和`__getitem__`方法。其中,`__len__`方法返回数据集的大小,`__getitem__`方法用于获取指定索引的数据。 2. 创建数据集实例:将定义好的数据集实例化,并传入相应的参数(如文件路径等)。 3. 创建数据加载器:使用`torch.utils.data.DataLoader`创建数据加载器,可以指定批次大小、是否打乱数据、多进程等参数。 4. 迭代数据:使用for循环迭代数据加载器,每次迭代返回一个批次的数据。 下面是一个简单的示例代码,用于加载MNIST数据集: ```python import torch from torch.utils.data import Dataset, DataLoader from torchvision import datasets, transforms # 定义自己的数据 class MyDataset(Dataset): def __init__(self, path): self.data = torch.load(path) self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) def __len__(self): return len(self.data) def __getitem__(self, index): x, y = self.data[index] x = self.transform(x) return x, y # 创建数据集实例 train_dataset = MyDataset('mnist/train.pt') test_dataset = MyDataset('mnist/test.pt') # 创建数据加载器 train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True) # 迭代数据 for batch_idx, (data, target) in enumerate(train_loader): # 对批次数据进行训练或测试 ... ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值