Pytorch中的数据加载

1. 模型中使用数据加载器的目的

在深度学习中,数据量通常是非常大的,不能一次性的在模型中进行向前的计算和反向传播,经常我们会对整个数据进行随机的打乱顺序,把数据处理成一个个batch,同时还会对数据进行预处理。

2. 数据集类

2.1 Dataset基类介绍

在torch中提供了数据集的基类torch.utils.data.Dataset,继承这个基类,我们就能够非常快速的实现对数据的加载。torch.utils.data.Dataset的源码如下:
在这里插入图片描述
可知:我们需要再自定义的数据集类中集成Dataset类,同时还需要实现两个方法:

  • __len__方法,能够实现通过全局的len()方法获取其中的元素个数
  • __getitem__方法,能够通过传入索引的方式获取数据,如dataset[i]获取其中的第i条数据。

2.2 数据加载案例

数据来源:http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
数据介绍:SMS Sapm Collection是用于骚扰短信识别的经典数据集,包括4831条正常短信和747条骚扰短信,正常短信和骚扰短信保存在一个文本文件中,每行完整记录一条短信内容,每行开头通过ham和spam标识正常短信和骚扰短信。
数据实例:
在这里插入图片描述
实现如下:

import torch
from torch.utils.data import Dataset

filePath = "SMSSpamCollection"

class My_dataset(Dataset):
    def __init__(self):
        self.lines = open(filePath,'r',encoding='utf-8').readlines()

    def __getitem__(self, index):
        # 获取索引对应位置的一条数据
        return self.lines[index]

    def __len__(self):
        # 返回数据的总数量
        return len(self.lines)
if __name__ == '__main__':
    my_dataset = My_dataset()
    print(my_dataset[0])

3. 迭代数据集

使用上述的方法能够进行数据的读取,但是其中还有很多内容没有实现:

  • 批处理数据(Batching the data)
  • 打乱数据(Shuffling the data)
  • 使用多线程multiprocessing并行加载数据
    在pytorch中torch.utils.data.DataLoader提供了上述的所用方法
from torch.utils.data import DataLoader
my_dataset = My_dataset()
data_loader = DataLoader(dataset=my_dataset,batch_size=2,shuffle=True,num_workers=2,drop_last=True)
for index,(label,content) in enumerate(data_loader):
    print(index,label,content)

DataLoader参数的含义:

  1. dataset:提前定义的dataset的实例
  2. batch_size:传入数据的batch的大小
  3. shuffle:bool类型,表示是否在每次获取数据的时候提前打乱数据
  4. num_workers:加载数据的线程数
  5. drop_last:如果最后一组不是完整的batch,就把最后的一组删除

4. pytorch自带的数据集

pytorch中自带的数据集由两个上层API提供,分别是torchvision和torchtext。
其中:

  1. torchvision提供了对图片数据处理相关的API和数据
  • 数据位置:torchvision.datasets,如:torchvision.datasets.MNIST(手写数字图片数据)
  1. torchtext提供了对文本数据处理相关的API和数据
  • 数据位置:torchtext.datasets,如:torchtext.datasets.IMDB(电影评论文本数据)
    使用方法和之前一暗影:
    1、准备好Dataset实例
    2、把dataset交给dataloader打乱顺序,组成batch

4.1 torchversion.dataset

torchversion.datasets 中的数据集类都是继承自Dataset,也就是说直接对torchvision.datasets,MNIST进行实例化就可以得到Dataset的实例,但是MNIST API 中的参数需要注意:

torchvision.dataset.MNIST(root='/files/',train = True,download = True,transform)
  1. root:参数表示数据存放的位置
  2. train:bool类型,表示是使用训练集的数据还是测试集的数据
  3. download:bool类型,表示是否需要下载数据到root目录
  4. transform:实现对图片的处理函数
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值