Pytorch加载数据集

目录

1.定义自己的dataset类

2.Sampler

3.定义自己的dataloader迭代器

4.遍历数据


1.定义自己的dataset类

需要继承torch.utils.data.Dataset类,并重写__init__(),__len__(), __getitem__()方法

数据增强操作(类或者函数)也在该类的__getitem方法中被调用

import torch
import numpy as np


# 继承Dataset方法,并重写__getitem__()和__len__()方法
class my_dataset(torch.utils.data.Dataset):
	# 初始化函数,得到数据
    def __init__(self, data_root, data_label, transform=None):
        self.data = data_root
        self.label = data_label
        self.transform = transform   #数据增强

    # inde是索引,最后将data和对应的labels进行一起返回
    def __getitem__(self, index):
        data = self.data[index]
        labels = self.label[index]
        if self.trannform:
            data, labels = self.transform(data, labels)
        return data, labels
    # 该函数返回数据大小长度,目的是DataLoader方便划分
    def __len__(self):
        return len(self.data)

# 随机生成数据,大小为10 * 20列
source_data = np.random.rand(10, 20)
# 随机生成标签,大小为10 * 1列
source_label = np.random.randint(0,2,(10, 1))
# 通过GetLoader将数据进行加载,返回Dataset对象,包含data和labels
torch_data = my_dataset(source_data, 
                        source_label,
                        transform=tranforms.compose([
                                    Rescale(256),
                                    ToTensOr()]))
class Rescale(object):
    """将图片调整为给定的大小.

    Args:
        output_size (tuple or int): 期望输出的图片大小. 如果是 tuple 类型,输出图片大小就是给定的 output_size;
                                    如果是 int 类型,则图片最短边将匹配给的大小,然后调整最大边以保持相同的比例。
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, data, labels):

        h, w = data.shape[:2]
        # 判断给定大小的形式,tuple 还是 int 类型
        if isinstance(self.output_size, int):
            # int 类型,给定大小作为最短边,最大边长根据原来尺寸比例进行调整
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(data, (new_h, new_w))

        # 根据调整前后的尺寸比例,调整关键点的坐标位置,并且 x 对应 w,y 对应 h
        labels = labels * [new_w / w, new_h / h]

        return img, labels

数据增强可以定义成类,而不是函数,这样就不需要每次都传递参数,为此需要实现__call__方法何__init__方法

2.Sampler

PyTorch为我们提供了几种现成的Sampler子类:

  • SequentialSampler
  • RandomSampler
  • SubsetRandomSampler
  • WeightedRandomSampler
  • BatchSampler
  • DistributedSampler

dataloader()中的shuffle=True时,默认的是RandomSampler,shuffle=false时默认的是SequentialSampler,一般不需要指定sampler,使用dataloader中默认指定的就行

if sampler is None:  # give default samplers
    if self._dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            # Cannot statically verify that dataset is Sized
            # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
            sampler = RandomSampler(dataset, generator=generator)  # type: ignore
        else:
            sampler = SequentialSampler(dataset)

if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)

3.定义自己的dataloader迭代器

通过torch.utils.data.DataLoader实现

dataloader = DataLoader(my_dataset, batch_size=4, shuffle=True, num_workers=4)

4.遍历数据

直接for循环

for i, data in enumerate(dataloader):
    imgs, targets = data

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值