Pytorch中的数据加载

Pytorch中的数据加载

1. 模型中使用数据加载器的目的

在前面的线性回归模型中,使用的数据很少,所以直接把全部数据放到模型中去使用。

但是在深度学习中,数据量通常是都非常多,非常大的,如此大量的数据,不可能一次性的在模型中进行向前的计算和反向传播,经常我们会对整个数据进行随机的打乱顺序,把数据处理成一个个的batch,同时还会对数据进行预处理。

所以,接下来学习pytorch中的数据加载的方法

2. 数据集类

2.1 Dataset基类介绍

在torch中提供了数据集的基类torch.utils.data.Dataset,继承这个基类,我们能够非常快速的实现对数据的加载。

torch.utils.data.Dataset的源码如下:

class Dataset(object):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

可知:我们需要在自定义的数据集类中继承Dataset类,同时还需要实现两个方法:

  1. __len__方法,能够实现通过全局的len()方法获取其中的元素个数

  2. __getitem__方法,能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第i条数据

def __len__(self):
        raise NotImplementedError

2.2 数据加载案例

下面通过一个例子来看看如何使用Dataset来加载数据

数据来源:http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection

数据介绍:SMS Spam Collection是用于骚扰短信识别的经典数据集,完全来自真实短信内容,包括4831条正常短信和747条骚扰短信。正常短信和骚扰短信保存在一个文本文件中。 每行完整记录一条短信内容,每行开头通过ham和spam标识正常短信和骚扰短信

数据实例:

实现如下:

import pandas as pd
from torch.utils.data import Dataset

data_path = './SMSSpamCollection'


class CifarDataset(Dataset):
    def __init__(self):
        lines = open(data_path, 'r', encoding='utf-8')
        #  对数据进行处理,前4个为label,后面的为短信内容
        lines = [[i[:4].strip(), i[4:].strip()] for i in lines]
        #  转化为dataFrame
        self.df = pd.DataFrame(lines, columns=['label', 'sms'])

    def __getitem__(self, item):
        single_item = self.df.iloc[item, :]
        return single_item.values[0], single_item.values[1]

    def __len__(self):
        return self.df.shape[0]


d = CifarDataset()
for i in range(len(d)):
    print(i, d[i])

运行结果:

3. 迭代数据集

使用上述的方法能够进行数据的读取,但是其中还有很多内容没有实现:

  • 批处理数据(Batching the data)

  • 打乱数据(Shuffling the data)

  • 使用多线程 multiprocessing 并行加载数据。

在pytorch中torch.utils.data.DataLoader提供了上述的所用方法

DataLoader的使用方法示例:

import pandas as pd
from torch.utils.data import Dataset, DataLoader

data_path = './SMSSpamCollection'


class CifarDataset(Dataset):
    def __init__(self):
        lines = open(data_path, 'r', encoding='utf-8')
        #  对数据进行处理,前4个为label,后面的为短信内容
        lines = [[i[:4].strip(), i[4:].strip()] for i in lines]
        #  转化为dataFrame
        self.df = pd.DataFrame(lines, columns=['label', 'sms'])

    def __getitem__(self, item):
        single_item = self.df.iloc[item, :]
        return single_item.values[0], single_item.values[1]

    def __len__(self):
        return self.df.shape[0]


dataset = CifarDataset()
data_loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True, num_workers=2)

if __name__ == '__main__':  # 下面代码必须放在main函数内,否则报错
    #  遍历,获取其中的每个batch的结果
    for index, (label, context) in enumerate(data_loader):
        print(index, label, context)
        print("*" * 30)

其中参数含义:

  1. dataset:提前定义的dataset的实例

  2. batch_size:传入数据的batch的大小,常用128,256等等

  3. shuffle:bool类型,表示是否在每次获取数据的时候提前打乱数据

  4. num_workers:加载数据的线程数

数据迭代器的返回结果如下: :

556 ('spam', 'ham', 'ham', 'spam', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham') ('Todays Voda numbers ending with 7634 are selected to receive a £350 reward. If you have a match please call 08712300220 quoting claim code 7684 standard rates apply.', 'Lol they were mad at first but then they woke up and gave in.', 'Good Morning my Dear........... Have a great & successful day.', 'Do you want a NEW video phone750 anytime any network mins 150 text for only five pounds per week call 08000776320 now or reply for delivery tomorrow', "No. To be nosy I guess. Idk am I over reacting if I'm freaked?", "No we sell it all so we'll have tons if coins. Then sell our coins to someone thru paypal. Voila! Money back in life pockets:)", 'at bruce b downs & fletcher now', "Yes, i'm small kid.. And boost is the secret of my energy..", 'Nice.nice.how is it working?', 'U buy newspapers already?')
******************************
557 ('ham', 'spam', 'ham', 'spam') ('hi baby im cruisin with my girl friend what r u up 2? give me a call in and hour at home if thats alright or fone me on this fone now love jenny xxx', 'We tried to call you re your reply to our sms for a video mobile 750 mins UNLIMITED TEXT + free camcorder Reply of call 08000930705 Now', "Hiya , have u been paying money into my account? If so, thanks. Got a pleasant surprise when i checked my balance -u c, i don't get statements 4 that acc", 'Congrats! Nokia 3650 video camera phone is your Call 09066382422 Calls cost 150ppm Ave call 3mins vary from mobiles 16+ Close 300603 post BCM4284 Ldn WC1N3XX')
******************************

注意:

  1. len(dataset) = 数据集的样本数

  2. len(dataloader) = math.ceil(样本数/batch_size) 即向上取整

4 pytorch自带的数据集

pytorch中自带的数据集由两个上层api提供,分别是torchvisiontorchtext

其中:

  1. torchvision提供了对图片数据处理相关的api和数据

    • 数据位置:torchvision.datasets,例如:torchvision.datasets.MNIST(手写数字图片数据)

  2. torchtext提供了对文本数据处理相关的API和数据

    • 数据位置:torchtext.datasets,例如:torchtext.datasets.IMDB(电影评论文本数据)

下面以Mnist手写数字为例,来看看pytorch如何加载其中自带的数据集

使用方法和之前一样:

  1. 准备好Dataset实例

  2. 把dataset交给dataloder 打乱顺序,组成batch

4.1 torchversion.datasets

torchversoin.datasets中的数据集类(比如torchvision.datasets.MNIST),都是继承自Dataset

意味着:直接对torchvision.datasets.MNIST进行实例化就可以得到Dataset的实例

但是MNIST API中的参数需要注意一下:

torchvision.datasets.MNIST(root='/files/', train=True, download=True, transform=)

  1. root参数表示数据存放的位置

  2. train:bool类型,表示是使用训练集的数据还是测试集的数据  【True是训练集】

  3. download:bool类型,表示是否需要下载数据到root目录

  4. transform:实现的对图片的处理函数

4.2 MNIST数据集的介绍

数据集的原始地址:http://yann.lecun.com/exdb/mnist/

MNIST是由Yann LeCun等人提供的免费的图像识别的数据集,其中包括60000个训练样本和10000个测试样本,其中图拍了的尺寸已经进行的标准化的处理,都是黑白的图像,大小为28X28

执行代码,下载数据,观察数据类型:

 

import torchvision

dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=None)
print(dataset[0])

下载的数据如下:

代码输出结果如下:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!
(<PIL.Image.Image image mode=L size=28x28 at 0x18D303B9C18>, tensor(5))

可以其中数据集返回了两条数据,可以猜测为图片的数据和目标值

返回值的第0个为Image类型,可以调用show() 方法打开,发现为手写数字5

import torchvision

dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=None)
print(dataset[0])

img = dataset[0][0]
#  打开图片
img.show() 

图片如下:

由上可知:返回值为(图片,目标值),这个结果也可以通过观察源码得到

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值