pytorch数据加载

1 为什么使用加载器?

在深度学习中,数据量通常是非常大的,面对如此大的数据,不可能一次性在模型中进行向前的计算或者反向传播,通常我们会对数据进行打乱顺序,把数据处理成一个个batch,同时会对数据进行预处理。

2 数据集类

2.1 Dataset基类

在 torch中提供了数据集的基类torch.utils.data.Dataset,继承这个基类,我们能够非常快速的实现对数据的加载

torch.util.data.Daatast源码如下

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

我们需要在自定义的数据集类中继承Dataset类,同时还需要实现两个方法:

  1. init方法,能够实现通过全局len()方法获取其中的元素个数
  2. getitem方法,能够通过传入索引方式获取数据,例如通过dataset[i]获取对应数据

2.2数据加载案例¶

数据来源:https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
如果网站打不开,点这个链接免费下载:
自然语言处理SMSSpamCollection数据集(免费分享)
数据介绍:SMS Spam Collection用于骚扰短信识别的经典数据集,完全来自于真实短信内容,包括了4831个正常短信和747个骚扰短信。 数据保存在一个txt文件中,每行完整记录一条内容,每行开通通过ham和spam来标识正常短信和骚扰短信。

# 完成数据集

class MyDataset(Dataset):
    def __init__(self):
        self.lines = open(data_path,encoding='utf-8').readlines()
    def __getitem__(self,index):
        # 获取索引对应位置的一条数据
        cur_line =  self.lines[index].strip()
        label = cur_line[:4].strip()
        content = cur_line[4:].strip()
        return label,content
    def __len__(self):
        # 返回数据总数量
        return len(self.lines)
my_dataset = MyDataset()
print(my_dataset[0])

输出结果

(‘ham’, ‘Go until jurong point, crazy… Available only in bugis n great world la e buffet… Cine there got amore wat…’)
包含了短信的标签ham(sham)和短信内容

但是至此我们并没有完成上文提到的

对数据进行打乱顺序,把数据处理成一个个batch,同时会对数据进行预处理。

所以我们需要引入下一个内容迭代数据集合。

3. 迭代数据集

作用:
1. 批处理数据 Batching the data
2. 打乱数据 Shuffling the data
3. 使用多线程multiprocessing 进行数据加载

在pytorch中torch.utils.DataLoader提供了上述所用的方法
DataLoader使用实例:

from torch.utils.data import DataLoader

data_loader = DataLoader(dataset=my_dataset,batch_size=2,shuffle=True)

参数含义:
1.dataset:提前定义好的dataset实例
2.batch_size:传入数据的batch的大小,常用128,256
3.shuffle:bool类型,是否打乱数据
4.num_worker:加载数据的线程数量

for i in data_loader:
    print(i)

输出结果:注意因为上面batch_size为2,所以一次迭代会输出两个。
在这里插入图片描述

# 自己构造一个迭代器查看迭代的结果
# enumerate 定义一个迭代器
for index,(label,context) in enumerate(data_loader):
    print(index,label,context)
    print("*"*100)
len(data_loader)

dataloader的长度应该为dataset的长度除以batch_size(批数量)
len(dataloader) = 向上取整(dataset/batch_size)

4.pytorch自带数据集合

pytorch中自带是数据集由两部分组成,分别是torchvision(图片数据)和torchtext(文本数据)
例如:
torchvision.dataset.MNIST(手写数字识别)
torchtext.dataset.IMDB(电影评论文本数据)

from torchvision.datasets import MNIST
mnist = MNIST(root="./data",train = True,download = True,)
print(mnist[0])
mnist[0][0].show()

参数介绍:

  1. root 参数表示数据存放的位置
  2. train:bool,表示是否使用训练集数据,False表示使用测试集数据
  3. download:bool,是否下载到root目录
  4. transform:实现对图片的处理

输出结果:
(<PIL.Image.Image image mode=L size=28x28 at 0x19F1E14F490>, 5)
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ZATuTu丶

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值