pytorch——数据加载

1、模型中使用数据加载器的目的

        深度学习中,数据量通常是非常多、量非常大的,大量的数据,不可能一次性地在模型中进行向前的计算和反向传播,通常,我们会对整个数据进行随机的打乱数据,把数据处理成一个一个的batch,同时还会对数据进行预处理。

2、数据集类

2.1Dataset基类介绍

        在torch中,提供了数据集的基类torch.utils.data.Dataset,继承这个基类,能够很快地实现对数据的加载。

        torch.utils.data.Dataset的源码如下:

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

        据此,我们可以知道,我们需要在自定义的数据集类中继承Dataset类,同时还需要实现两个方法:

        __len__方法,能够实现通过全局的len()方法获取其中的元素个数。

        __getitem__方法,能够通过传入索引的方式获取数据,例如通过dataset[i]调取第i条数据。

2.2数据加载案例

        通过一个例子使用Dataset来加载数据。

        数据来源:UCI Machine Learning Repository: SMS Spam Collection Data Set

        数据介绍:SMS Spam Collection是用于骚扰短信识别的经典数据集,完全来自真实短信内容,包含4831条正常短信和747条骚扰短信,正常短信和骚扰短信保存在一个文本文件中,每行完整记录一条短信内容,每行开头通过ham和spam标识正常短信和骚扰短信。内容如下:

         首先,我们实现数据集类:

import torch
from torch.utils.data import Dataset

data_path = r"D:\我的文档\daima\python\pytorch\mnist\短信数据\SMSSpamCollection"#加r标识是一个字符串,否则可能认为\后为特殊含义

#完成数据集
class MyDataset(Dataset):
    def __init__(self):
        self.lines = open(data_path,encoding='utf-8').readlines()

    def __getitem__(self,index):
        #获取索引对应位置的一条数据
        return self.lines[index].strip()#除去换行符

    def __len__(self):
        #返回数据的总数量
        return len(self.lines)

       strip()作用是处理字符串左右两边的空白符(包括'\n', '\r',  '\t',  ' ')或指定的字符

       接着,我们可以对Dataset进行实例化,迭代获取其中的数据:

if __name__ == '__main__':
    my_dataset = MyDataset()
    for i in range(4):
        print(my_dataset[i])

        可以看到输出如下:

         此时,标签和文本内容并没有分开,所以对__getitem方法进行修改:

    def __getitem__(self,index):
        cur_line = self.lines[index].strip()
        label = cur_line[:4].strip()
        content = cur_line[4:].strip()
        return label,content

        输出是一个元组。

3、迭代数据集

        使用上述的方法能够进行数据的读取,但是其中还有很多内容没有实现:

  • 批处理数据(Batching the data)
  • 打乱数据(Shuffling the data)
  • 使用多线程multiprocessing并行加载数据

        在pytorch中torch.utils.data.DataLoader提供了上述所有的方法

        实例:

import torch
from torch.utils.data import DataLoader,Dataset

data_path = r"D:\我的文档\daima\python\pytorch\mnist\短信数据\SMSSpamCollection"#加r标识是一个字符串,否则可能认为\后为特殊含义

#完成数据集
class MyDataset(Dataset):
    def __init__(self):
        self.lines = open(data_path,encoding='utf-8').readlines()

    def __getitem__(self,index):
        cur_line = self.lines[index].strip()
        label = cur_line[:4].strip()
        content = cur_line[4:].strip()
        return label,content

    def __len__(self):
        #返回数据的总数量
        return len(self.lines)

my_dataset = MyDataset()
data_loader = DataLoader(dataset = my_dataset,batch_size = 2,shuffle = True,num_workers = 2)

if __name__ == '__main__':
    for i in data_loader:
        print(i)

        DataLoader()函数参数的含义:

  1. dataset:提前定义的dataset的实例;
  2. batch_size:传入数据的batch的大小,常用128,256等等;
  3. shuffle:bool类型,表示是否在每次获取数据时提前进行打乱;
  4. num_workers:加载数据的线程数。

        输出的部分结果如下:

         当然,遍历时可以使用enumerate关键字,将对象组成一个一个序列和索引,可以同时获得索引和索引值。需要打印数据的长度时,可以使用ceil函数(将给定的值转换成大于或等于它的最小整数)。

    for i in enumerate(data_loader):
        print(i)

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值