pytorch之数据集构造

这些天看的东西,真的是比较多,相比以前来说,对我的学习方式起到颠覆性作用。我目前觉得,我们学到的东西,更多是孤立的,因此,在吸收一定知识后,需要在脑子里形成知识体系。需要把自己以前学到的东西进行整理,形成一个体系,这篇文章讲解的是,深度学习中pytorch数据集的构造!!!

pytorch中有两个自定义管理数据集的类,

  1. torch.utils.data.DataSet
  2. torvchvision.datasets.ImageFolder

这里主要讲解的第一种。

DataSet源码

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

我们设计自己数据集类的时候, 只需要重写 __getitem__、__len__两个函数,分别的功能是, 通过切片返回具样例返回样本个数
以下是voc2012数据集分割的例子:

import os

import numpy as np
from PIL import Image
from torch.utils import data


def read_images(root, train):
    txt_fname = os.path.join(root, 'ImageSets/Segmentation/') + ('train.txt' if train else 'val.txt')
    with open(txt_fname, 'r') as f:
        images = f.read().split()
    data = [os.path.join(root, 'JPEGImages', i + '.jpg') for i in images]
    label = [os.path.join(root, 'SegmentationClass', i + '.png') for i in images]
    return data, label


class VocSegDataset(data.Dataset):

    def __init__(self, cfg, train, transforms=None):
        self.cfg = cfg
        self.train = train
        self.transforms = transforms
        self.data_list, self.label_list = read_images(self.cfg.DATASETS.ROOT, train)

    def __getitem__(self, item):
        img = self.data_list[item]
        label = self.label_list[item]
        img = Image.open(img)
        # load label
        label = Image.open(label)
        img, label = self.transforms(img, label)
        return img, label

    def __len__(self):
        return len(self.data_list)

通过上面的操作,我们构建自己数据集类,接下来,构建一个 Dataloader类,这个作用是训练过程中,返回 batch个样例。

Dataloder源码

由于源码过于臃肿了,这里知识摘出对应的构造函数:

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):

构造函数中,每个参数的意思就不一一介绍了,只着重的讲解下,可调用函数collate_fn。我们首先看一个构建Dataloader的实例:

def build_dataset(cfg, transforms, is_train=True):
    datasets = VocSegDataset(cfg, is_train, transforms)
    return datasets


def make_data_loader(cfg, is_train=True):
    if is_train:
        batch_size = cfg.SOLVER.IMS_PER_BATCH
        shuffle = True
    else:
        batch_size = cfg.TEST.IMS_PER_BATCH
        shuffle = False

    transforms = build_transforms(cfg, is_train)
    datasets = build_dataset(cfg, transforms, is_train)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    data_loader = data.DataLoader(
        datasets, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True
    )

    return data_loader

上面第一个函数 build_dataset返回数据集实例,第二个函数返回Dataloader,关于Dataloader,我们需要注意的是,有时我们需要根据Dataset中的__getitem__修改collate_fn
我们来看下源码:

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

我们在源码中发现,collate_fn的输入是一个list,里面的每个元素是__getitem__的输出,由此,我们估计,default_collate的作用是将这个list,**变换格式为[batch,C,H,W]**的tensor,我们在来看下源码:

	if.......
	.........
    elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'):  # namedtuple
        return type(batch[0])(*(default_collate(samples) for samples in zip(*batch)))

由于源码均是对类型的判断,因此,这里我们知识摘出,与voc2012分割相关的部分,这个语句的意思是, 对[(img1, label1), (img2, label2)],首先返回[img1,img2],[lable1,label2],在继续返回两个tensor,一个是img,[batch,C,H,W],一个是label:[batch,C,H,W]。
所以,通过上面分析,如果,我们__getitem__不符合collat_fn不符合默认函数的判断时,需要修改该函数。
好了,先到这,接下来…慢慢聊程序,需要学的太多了

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch中加载数据集通常有两种常见的方法:使用自定义数据集和使用预定义数据集。 1. 使用自定义数据集: - 创建一个新的Python类,继承`torch.utils.data.Dataset`,并实现`__len__`和`__getitem__`方法。`__len__`返回数据集的大小,`__getitem__` 根据给定索引返回样本。 - 在`__init__`方法中,根据需求加载数据集并对其进行预处理。 - 可以使用PyTorch提供的各种数据转换方法(例如`torchvision.transforms`)来对数据进行预处理。 - 在训练代码中,实例化自定义数据集类,并使用`torch.utils.data.DataLoader`将数据加载到训练循环中。 下面是一个简单的自定义数据集加载示例: ```python import torch from torch.utils.data import Dataset, DataLoader class CustomDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data[idx] # 在这里进行数据预处理 return torch.Tensor(sample) # 假设有一个包含样本的列表 data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] dataset = CustomDataset(data) dataloader = DataLoader(dataset, batch_size=2, shuffle=True) for batch in dataloader: # 在这里执行训练循环 print(batch) ``` 2. 使用预定义数据集: - PyTorch提供了一些预定义的数据集,如`torchvision.datasets`模块中的MNIST、CIFAR10等。 - 可以使用预定义数据集构造函数来加载数据集,并根据需要进行转换和预处理。 - 同样,可以使用`torch.utils.data.DataLoader`将数据加载到训练循环中。 下面是一个预定义数据集加载示例: ```python import torch import torchvision from torchvision import transforms # 定义数据转换和预处理 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) # 加载MNIST数据集 train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True) test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform) # 使用DataLoader加载数据集 train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False) for batch in train_dataloader: # 在这里执行训练循环 images, labels = batch print(images.shape, labels.shape) ``` 这些是基本的加载数据集的方法,你可以根据自己的需求进行修改和扩展。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值