1 为什么使用加载器?
在深度学习中,数据量通常是非常大的,面对如此大的数据,不可能一次性在模型中进行向前的计算或者反向传播,通常我们会对数据进行打乱顺序,把数据处理成一个个batch,同时会对数据进行预处理。
2 数据集类
2.1 Dataset基类
在 torch中提供了数据集的基类torch.utils.data.Dataset,继承这个基类,我们能够非常快速的实现对数据的加载
torch.util.data.Daatast源码如下
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
我们需要在自定义的数据集类中继承Dataset类,同时还需要实现两个方法:
- init方法,能够实现通过全局len()方法获取其中的元素个数
- getitem方法,能够通过传入索引方式获取数据,例如通过dataset[i]获取对应数据
2.2数据加载案例¶
数据来源:https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
如果网站打不开,点这个链接免费下载:
自然语言处理SMSSpamCollection数据集(免费分享)
数据介绍:SMS Spam Collection用于骚扰短信识别的经典数据集,完全来自于真实短信内容,包括了4831个正常短信和747个骚扰短信。 数据保存在一个txt文件中,每行完整记录一条内容,每行开通通过ham和spam来标识正常短信和骚扰短信。
# 完成数据集
class MyDataset(Dataset):
def __init__(self):
self.lines = open(data_path,encoding='utf-8').readlines()
def __getitem__(self,index):
# 获取索引对应位置的一条数据
cur_line = self.lines[index].strip()
label = cur_line[:4].strip()
content = cur_line[4:].strip()
return label,content
def __len__(self):
# 返回数据总数量
return len(self.lines)
my_dataset = MyDataset()
print(my_dataset[0])
输出结果
(‘ham’, ‘Go until jurong point, crazy… Available only in bugis n great world la e buffet… Cine there got amore wat…’)
包含了短信的标签ham(sham)和短信内容
但是至此我们并没有完成上文提到的
对数据进行打乱顺序,把数据处理成一个个batch,同时会对数据进行预处理。
所以我们需要引入下一个内容迭代数据集合。
3. 迭代数据集
作用:
1. 批处理数据 Batching the data
2. 打乱数据 Shuffling the data
3. 使用多线程multiprocessing 进行数据加载
在pytorch中torch.utils.DataLoader提供了上述所用的方法
DataLoader使用实例:
from torch.utils.data import DataLoader
data_loader = DataLoader(dataset=my_dataset,batch_size=2,shuffle=True)
参数含义:
1.dataset:提前定义好的dataset实例
2.batch_size:传入数据的batch的大小,常用128,256
3.shuffle:bool类型,是否打乱数据
4.num_worker:加载数据的线程数量
for i in data_loader:
print(i)
输出结果:注意因为上面batch_size为2,所以一次迭代会输出两个。
# 自己构造一个迭代器查看迭代的结果
# enumerate 定义一个迭代器
for index,(label,context) in enumerate(data_loader):
print(index,label,context)
print("*"*100)
len(data_loader)
dataloader的长度应该为dataset的长度除以batch_size(批数量)
len(dataloader) = 向上取整(dataset/batch_size)
4.pytorch自带数据集合
pytorch中自带是数据集由两部分组成,分别是torchvision(图片数据)和torchtext(文本数据)
例如:
torchvision.dataset.MNIST(手写数字识别)
torchtext.dataset.IMDB(电影评论文本数据)
from torchvision.datasets import MNIST
mnist = MNIST(root="./data",train = True,download = True,)
print(mnist[0])
mnist[0][0].show()
参数介绍:
- root 参数表示数据存放的位置
- train:bool,表示是否使用训练集数据,False表示使用测试集数据
- download:bool,是否下载到root目录
- transform:实现对图片的处理
输出结果:
(<PIL.Image.Image image mode=L size=28x28 at 0x19F1E14F490>, 5)