PyTorch将数据集的处理过程标准化,提供了Dataset基本的数据 类,并在torchvision中提供了众多数据变换函数,数据加载的具体过程 主要分为3步:
1.继承Dataset类
对于数据集的处理,PyTorch提供了torch.utils.data.Dataset这个抽象 类,在使用时只需要继承该类,并重写__len__()和__getitem()__函数, 即可以方便地进行数据集的迭代。
from torch.utils.data import Dataset
class my_data(Dataset):
def _init_(self,image_path,annotation_path,transform-None):
#初始化,读取数据集
def _len_(self):
#获取数据集的总大小
def _getitem_(self,id):
#对于指定的id,读取数据并返回
对上述初始化的·实际使用:
dataset = my_data("your image path", "your annotation path")
# 实例化该类 for data in dataset:
print(data)
2.数据变换与增强:torchvision.transforms
第一步虽然将数据集加载到了实例中,但在实际应用时,数据集中 的图片有可能存在大小不一的情况,并且原始图片像素RGB值较大 (0~255),这些都不利于神经网络的训练收敛,因此还需要进行一些 图像变换工作。PyTorch为此提供了torchvision.transforms工具包,可以 方便地进行图像缩放、裁剪、随机翻转、填充及张量的归一化等操作, 操作对象是PIL的Image或者Tensor。
如果需要进行多个变换功能,可以利用transforms.Compose将多个 变换整合起来,并且在实际使用时,通常会将变换操作集成到Dataset的 继承类中。具体示例如下:
from torchvision import transforms
#将transform集成到dataset类中,使用compose将多个变换整合到一起
dataset = my_data("your image path","your annotation path",transforms=transforms.Compose([
transforms.Resize(256) #将图像最短边缩小至256,宽高比例不变
#以0.5的概率随即翻转指定的PIL图像
transforms.RandomHorizaontalFlip()
#将PIL图像转为Tensor,元素区间从[0,255]归一化到[0,1]
transforms.ToTensor()
#进行mean与std为0.5的标准化
transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
]))
3.继承dataloader
经过前两步已经可以获取每一个变换后的样本,但是仍然无法进行 批量处理、随机选取等操作,因此还需要torch.utils.data.Dataloader类进 一步进行封装,使用方法如下例所示,该类需要4个参数,第1个参数是 之前继承了Dataset的实例,第2个参数是批量batch的大小,第3个参数是 是否打乱数据参数,第4个参数是使用几个线程来加载数据。
from torch.utils.data import Dataloader
# 使用Dataloader进一步封装Dataset
dataloader = Dataloader(dataset, batch_size=4, shuffle=True, num_workers=4)
dataloader是一个可迭代对象,对该实例进行迭代即可用于训练过程。
data_iter = iter(dataloader)
for step in range(iters_per_epoch):
data = next(data_iter) # 将data用于训练网络即可