PyTorch数据加载方法
数据集介绍
本文使用的数据集为开源的本文分类数据集SMS Spam Collection Data Set,下载地址为https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
数据集是从 Grumbletext 网站手动提取了 425 条 SMS 垃圾邮件的集合,由一个文本文件构成,其中每一行都是有一个类别和后面的原始消息构成。
Dataset类的使用详解
对与不同类型的数据集所需要的设置不同的Dataset,通常通过继承pytorch中的Dataset类进而构建模型训练需要的dataset。下面以SMS数据集为例,Dataset类的使用方法:
class MyDataset(Dataset):
def __init__(self):
self.lines = open(data_path, encoding='UTF-8').readlines()
def __getitem__(self, index):
# 获取索引对应位置的一条数据
cur_line = self.lines[index].strip() # strip取消换行符
label = cur_line[:4].strip()
content = cur_line[4:].strip()
return label, content
def __len__(self):
# 返回数据总数
return len(self.lines)
可以根据自己的数据集的实际情况来修改这三个方法:
- __init__方法可以用来设置读取数据集等初始化数据集的基本操作
- __getitem__方法通常用来根据索引来返回一条对应的数据内容
- __len__方法通常用来返回数据总数
使用如下代码展示一下读取后的效果:
my_dataset = MyDataset()
print(my_dataset[0])
print(len(my_dataset))
其中每一条数据是以一个元组的形式保存在Dataset数据集中,元组的第一个元素为标签,第二个元素为数据内容。
DataLoader类的使用详解
DataLoader的主要作用是将Dataset处理后的数据集进行加载整合成batch用于后续训练,使用方法如下:
from torch.utils.data import DataLoader
data_loader = DataLoader(dataset=my_dataset, batch_size=2, shuffle=True, num_workers=2)
DataLoader类主要参数如下:
- dataset:经过Dataset类处理过的数据集
- batch_size:一个batch中包含几条数据
- shuffle:是否打乱顺序
- sample:用于自定义从数据集中抽取样本的策略与方法,每次返回一个随机的索引,与shuffle互斥,如果使用了shuffle,则无法使用 sample
- batch_sampler :与shuffle类似,但是每次返回一批随机的索引,与batch_size、shuffle、sample和drop_last互斥
- num_work:多线程加速读取数据
- pin_memory:是否将数据放在dataloader返回前将Tensors复制到设备或者CUDA夹层内存中
- drop_last:用于判断是否放弃最后一个不完整的batch
使用for循环来展示一下效果:
for i in data_loader:
print(i)
break
当batch_size为2时可以看出这个列表中有两个元组,每个元组有条个数据。第一个元组存放标签,第二个元组存放着数据的内容。
在实际项目中通常使用enumerate方法在读取每一个batch内容的同时也返回其batch的索引:
for index, (label, content) in enumerate(data_loader):
print(index, label, content)
break
效果如下: