【PyTorch】数据加载

  • pytorch使用torch.utils.data对常用的数据加载进行封装,可以实现多线程预读取批量加载
  • 主要包括两个方面:1)把数据包装成Dataset类;2)用DataLoader加载。
  • TensorDataset可以直接接受Tensor类型的输入,并用DataLoader进行加载;省去自定义的过程。

官方数据集

  • torchvision中实现了一些常用的数据集,可以通过torchvision.datasets直接调用。如:MNIST,COCO,Captions,Detection,LSUN,ImageFolder,Imagenet-12,CIFAR,STL10,SVHN,PhotoTour。
  • torchvision.transforms提供了许多图像操作,可以很方便的进行数据增强。

一个典型的CIFAR10数据加载过程如下:

import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # 可以加入更多数据增强处理

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

自定义数据加载

如果不使用官方数据集,想要加载自己的数据集,需要自定义一个dataset类。它需要继承torch.utils.data.Dataset类,并实现__getitem__()__len__()两个成员方法。

下面是一个自定义的视频数据集的例子:

from torch.utils.data import Dataset

class FrameDataset(Dataset):
    # 初始化时候要有一个数据加载,可以是数据路径的列表,或者直接把数据全加载进来。
    # 后者实际上没有用到批量加载的功能,需要注意内存占用。
    def __init__(self, data_dir, transform): 
        with open(data_dir, 'r') as fr:
            reader = csv.reader(fr)
            self.video_files = [video for video, label in reader]
        self.transform = transform
        print("dataset size: ", len(self.video_files))
    def __getitem__(self, index):
        video_file = self.video_files[index]
        # 读取视频的帧并返回
        return imgs, num_img
    def __len__(self):
        return len(self.video_files)

注意:

  • 如果数据类型是图像、视频,我们可以把原始数据保存在一个文件夹中,再用一个列表保存图像或视频的路径。这样数据集初始化时候加载的其实只是路径列表,在训练和测试时才会分批把原始数据读入。
  • 如果原始数据是直接以数字形式存储在一个文件中,无法通过索引单个读取,可以在初始化时候把整个矩阵读入,然后每次getitem时返回其中一行。

自定义的一大优点是处理更灵活,例如对于视频或文本数据,getitem函数中返回的帧序列或句子序列往往是长度不固定的,默认情况下DataLoaderstack时会出错,这时可以用collate_fn指定batch数据的连接方式:

def collate_fn(batch):
    imgs, num_img = zip(*batch)
    return torch.cat(imgs), num_img

然后就可以正常加载数据了:

dataset = FrameDataset(csv_file, transform=tfms)
videoloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

numpy类型数据加载

如果数据需要一次性全部读入,而且不需要额外的复杂处理的话可以不用自定义数据集Dataset类。

比如通常情况下,我们的输入可以很容易处理成一个numpy类型。这时可以不用定义Dataset类,直接使用TensorDataset,只要把读入的数据转化成一个tensor传入即可。

random_split是一个可以自动划分数据集的函数,实现随机不重复划分的功能。

from torch.utils.data import TensorDataset,DataLoader,random_split

dataset = TensorDataset(torch.from_numpy(data))

n_train = int(len(dataset) * 0.9)
n_test = len(dataset) - n_train
trainset, testset = random_split(dataset, [n_train, n_test])
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值