【NLP理论到实战】06 pytorch中的数据加载

Pytorch中的数据加载

目标

  1. 知道数据加载的目的
  2. 知道pytorch中Dataset的使用方法
  3. 知道pytorch中DataLoader的使用方法
  4. 知道pytorch中的自带数据集如何获取

1. 模型中使用数据加载器的目的

在前面的线性回归模型中,我们使用的数据很少,所以直接把全部数据放到模型中去使用。


但是在深度学习中,数据量通常是都非常多,非常大的,如此大量的数据,不可能一次性的在模型中进行向前的计算和反向传播,经常我们会对整个数据进行随机的打乱顺序,把数据处理成一个个的batch,同时还会对数据进行预处理。


所以,接下来我们来学习pytorch中的数据加载的方法

2. 数据集类Dataset

2.1 Dataset基类介绍

在torch中提供了数据集的基类torch.utils.data.Dataset,继承这个基类,我们能够非常快速的实现对数据的加载。

torch.utils.data.Dataset的源码如下:

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

可知:我们需要在自定义的数据集类中继承Dataset类,同时还需要实现两个方法:

  1. __len__方法,能够实现通过全局的len()方法获取其中的元素个数
  2. __getitem__方法,能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第i条数据

2.2 数据加载案例

下面通过一个例子来看看如何使用Dataset来加载数据
数据来源:http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
数据介绍:SMS Spam Collection是用于骚扰短信识别的经典数据集,完全来自真实短信内容,包括4831条正常短信和747条骚扰短信。正常短信和骚扰短信保存在一个文本文件中。 每行完整记录一条短信内容,每行开头通过ham和spam标识正常短信和骚扰短信

数据实例:

在这里插入图片描述

实现如下:
注:Python strip() 方法用于移除字符串头尾指定的字符(默认为空格或换行符)或字符序列

from torch.utils.data import Dataset
import pandas as pd

load_path = r"./data/SMSS/SMSSpamCollection"

class mysmsDataset(Dataset):
    def __init__(self):
        mydata = open(load_path, 'r', encoding='UTF-8')
        lines = [[line[:4].strip(), line[4:].strip()] for line in mydata]  # 对数据进行处理,前4个为label,后面的为短信内容
        self.df = pd.DataFrame(lines, columns=['label', 'sms'])  # 转化为dataFrame

    def __getitem__(self, index):
        single_item = self.df.iloc[index, :]
        return single_item.values[0], single_item.values[1]

    def __len__(self):
        return self.df.shape[0]

之后对继承Dataset基类的mysmsDataset自定义子类进行实例化,可以获取其中的数据

sms = mysmsDataset()
print(sms[0],'\n',sms[5])
print(len(sms))

输出如下:

('ham', 'Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...') 
 ('spam', "FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv")
5574

3. 数据加载器类DataLoader

使用上述的方法能够进行数据的读取,但是其中还有很多内容没有实现:

  • 批处理数据(Batching the data)
  • 打乱数据(Shuffling the data)
  • 使用多线程 multiprocessing 并行加载数据。

在pytorch中torch.utils.data.DataLoader提供了上述的所用方法
DataLoader的使用方法示例:

from torch.utils.data import DataLoader

sms = mysmsDataset()
smsdata_loader = DataLoader(dataset=sms,batch_size=7,shuffle=True,drop_last=False)

#遍历,获取其中的每个batch的结果
for index, (label, context) in enumerate(smsdata_loader):
    print(index,label,context)
    break  # 这里数据过多就看一个符不符合

其中参数含义:

  1. dataset:提前定义的dataset的实例
  2. batch_size:传入数据的batch的大小,常用128,256等等
  3. shuffle:bool类型,表示是否在每次获取数据的时候提前打乱数据
  4. num_workers:加载数据的线程数(如果num_workers的值大于0,要在运行的部分放进__main__()函数里,才不会有错)

数据迭代器的返回结果如下:

0 ('ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham') ('Shall i send that exe to your mail id.', 'Bring tat cd don forget', 'Babe ! How goes that day ? What are you doing ? Where are you ? I sip my cappuccino and think of you, my love ... I send a kiss to you from across the sea', 'Er yeah, i will b there at 15:26, sorry! Just tell me which pub/cafe to sit in and come wen u can', 'A lot of this sickness thing going round. Take it easy. Hope u feel better soon. Lol', "I'm always on yahoo messenger now. Just send the message to me and i.ll get it you may have to send it in the mobile mode sha but i.ll get it. And will reply.", "Hello, my boytoy! I made it home and my constant thought is of you, my love. I hope your having a nice visit but I can't wait till you come home to me ...*kiss*")

注意:

  1. len(dataset) = 数据集的样本数
  2. len(dataloader) = math.ceil(样本数/batch_size) 即向上取整
from math import ceil

print(len(sms))
print(len(smsdata_loader))  # 向上取整,如果drop_last=False不去掉最后一个batch(默认为True)
print(len(sms)/7)
print(ceil(len(sms)/7))
5574
797
796.2857142857143
797

4 pytorch自带的数据集

pytorch中自带的数据集由两个上层api提供,分别是torchvisiontorchtext
其中:

  1. torchvision提供了对图片数据处理相关的api和数据
    数据位置:torchvision.datasets,例如:torchvision.datasets.MNIST(手写数字图片数据)
  2. torchtext提供了对文本数据处理相关的API和数据
    数据位置:torchtext.datasets,例如:torchtext.datasets.IMDB(电影评论文本数据)

下面我们以Mnist手写数字为例,来看看pytorch如何加载其中自带的数据集
使用方法和之前一样:

  1. 准备好Dataset实例
  2. 把dataset交给dataloder 打乱顺序,组成batch

4.1 torchversion.datasets

torchversoin.datasets中的数据集类(比如torchvision.datasets.MNIST),都是继承自Dataset
意味着:直接对torchvision.datasets.MNIST进行实例化就可以得到Dataset的实例

但是MNIST API中的参数需要注意一下:
torchvision.datasets.MNIST(root='/files/', train=True, download=True, transform=)

  1. root参数表示数据存放的位置
  2. train:bool类型,表示是使用训练集的数据还是测试集的数据
  3. download:bool类型,表示是否需要下载数据到root目录
  4. transform:实现的对图片的处理函数

4.2 MNIST数据集的介绍

数据集的原始地址:http://yann.lecun.com/exdb/mnist/


MNIST是由Yann LeCun等人提供的免费的图像识别的数据集,其中包括60000个训练样本和10000个测试样本,其中图拍了的尺寸已经进行的标准化的处理,都是黑白的图像,大小为28X28

执行代码,下载数据,观察数据类型:

import torchvision

mnist_dataset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=None)

print(mnist_dataset)
print(mnist_dataset[0])

下载的数据如下:
在这里插入图片描述

代码输出结果如下:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!
Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
(<PIL.Image.Image image mode=L size=28x28 at 0x265F4B2AD60>, 5)

可以其中数据集返回了两条数据,可以猜测为图片的数据和目标值
返回值的第0个为Image类型,可以调用show() 方法打开,发现为手写数字5

img = mnist_dataset[0][0]
img.show() #打开图片

图片如下:
在这里插入图片描述
由上可知:返回值为(图片,目标值),这个结果也可以通过观察源码得到

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值