文章目录
Pytorch中的数据加载
目标
- 知道数据加载的目的
- 知道pytorch中Dataset的使用方法
- 知道pytorch中DataLoader的使用方法
- 知道pytorch中的自带数据集如何获取
1. 模型中使用数据加载器的目的
在前面的线性回归模型中,我们使用的数据很少,所以直接把全部数据放到模型中去使用。
但是在深度学习中,数据量通常是都非常多,非常大的,如此大量的数据,不可能一次性的在模型中进行向前的计算和反向传播,经常我们会对整个数据进行随机的打乱顺序,把数据处理成一个个的batch,同时还会对数据进行预处理。
所以,接下来我们来学习pytorch中的数据加载的方法
2. 数据集类Dataset
2.1 Dataset基类介绍
在torch中提供了数据集的基类
torch.utils.data.Dataset
,继承这个基类,我们能够非常快速的实现对数据的加载。
torch.utils.data.Dataset
的源码如下:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
可知:我们需要在自定义的数据集类中继承Dataset类,同时还需要实现两个方法:
__len__
方法,能够实现通过全局的len()
方法获取其中的元素个数__getitem__
方法,能够通过传入索引的方式获取数据,例如通过dataset[i]
获取其中的第i
条数据
2.2 数据加载案例
下面通过一个例子来看看如何使用Dataset来加载数据
数据来源:http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
数据介绍:SMS Spam Collection是用于骚扰短信识别的经典数据集,完全来自真实短信内容,包括4831条正常短信和747条骚扰短信。正常短信和骚扰短信保存在一个文本文件中。 每行完整记录一条短信内容,每行开头通过ham和spam标识正常短信和骚扰短信
数据实例:
实现如下:
注:Pythonstrip()
方法用于移除字符串头尾指定的字符(默认为空格或换行符)或字符序列
from torch.utils.data import Dataset
import pandas as pd
load_path = r"./data/SMSS/SMSSpamCollection"
class mysmsDataset(Dataset):
def __init__(self):
mydata = open(load_path, 'r', encoding='UTF-8')
lines = [[line[:4].strip(), line[4:].strip()] for line in mydata] # 对数据进行处理,前4个为label,后面的为短信内容
self.df = pd.DataFrame(lines, columns=['label', 'sms']) # 转化为dataFrame
def __getitem__(self, index):
single_item = self.df.iloc[index, :]
return single_item.values[0], single_item.values[1]
def __len__(self):
return self.df.shape[0]
之后对继承Dataset基类的mysmsDataset自定义子类进行实例化,可以获取其中的数据
sms = mysmsDataset()
print(sms[0],'\n',sms[5])
print(len(sms))
输出如下:
('ham', 'Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...')
('spam', "FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv")
5574
3. 数据加载器类DataLoader
使用上述的方法能够进行数据的读取,但是其中还有很多内容没有实现:
- 批处理数据(Batching the data)
- 打乱数据(Shuffling the data)
- 使用多线程
multiprocessing
并行加载数据。
在pytorch中
torch.utils.data.DataLoader
提供了上述的所用方法
DataLoader
的使用方法示例:
from torch.utils.data import DataLoader
sms = mysmsDataset()
smsdata_loader = DataLoader(dataset=sms,batch_size=7,shuffle=True,drop_last=False)
#遍历,获取其中的每个batch的结果
for index, (label, context) in enumerate(smsdata_loader):
print(index,label,context)
break # 这里数据过多就看一个符不符合
其中参数含义:
- dataset:提前定义的dataset的实例
- batch_size:传入数据的batch的大小,常用128,256等等
- shuffle:bool类型,表示是否在每次获取数据的时候提前打乱数据
num_workers
:加载数据的线程数(如果num_workers的值大于0,要在运行的部分放进__main__()函数里,才不会有错)
数据迭代器的返回结果如下:
0 ('ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham') ('Shall i send that exe to your mail id.', 'Bring tat cd don forget', 'Babe ! How goes that day ? What are you doing ? Where are you ? I sip my cappuccino and think of you, my love ... I send a kiss to you from across the sea', 'Er yeah, i will b there at 15:26, sorry! Just tell me which pub/cafe to sit in and come wen u can', 'A lot of this sickness thing going round. Take it easy. Hope u feel better soon. Lol', "I'm always on yahoo messenger now. Just send the message to me and i.ll get it you may have to send it in the mobile mode sha but i.ll get it. And will reply.", "Hello, my boytoy! I made it home and my constant thought is of you, my love. I hope your having a nice visit but I can't wait till you come home to me ...*kiss*")
注意:
len(dataset) = 数据集的样本数
len(dataloader) = math.ceil(样本数/batch_size) 即向上取整
from math import ceil
print(len(sms))
print(len(smsdata_loader)) # 向上取整,如果drop_last=False不去掉最后一个batch(默认为True)
print(len(sms)/7)
print(ceil(len(sms)/7))
5574
797
796.2857142857143
797
4 pytorch自带的数据集
pytorch中自带的数据集由两个上层api提供,分别是
torchvision
和torchtext
其中:
torchvision
提供了对图片数据处理相关的api和数据
数据位置:torchvision.datasets
,例如:torchvision.datasets.MNIST
(手写数字图片数据)torchtext
提供了对文本数据处理相关的API和数据
数据位置:torchtext.datasets
,例如:torchtext.datasets.IMDB(电影
评论文本数据)
下面我们以Mnist手写数字为例,来看看pytorch如何加载其中自带的数据集
使用方法和之前一样:
- 准备好Dataset实例
- 把dataset交给dataloder 打乱顺序,组成batch
4.1 torchversion.datasets
torchversoin.datasets
中的数据集类(比如torchvision.datasets.MNIST
),都是继承自Dataset
意味着:直接对torchvision.datasets.MNIST
进行实例化就可以得到Dataset
的实例
但是MNIST API中的参数需要注意一下:
torchvision.datasets.MNIST(root='/files/', train=True, download=True, transform=)
root
参数表示数据存放的位置train:
bool类型,表示是使用训练集的数据还是测试集的数据download:
bool类型,表示是否需要下载数据到root目录transform:
实现的对图片的处理函数
4.2 MNIST数据集的介绍
数据集的原始地址:
http://yann.lecun.com/exdb/mnist/
MNIST是由Yann LeCun
等人提供的免费的图像识别的数据集,其中包括60000个训练样本和10000个测试样本,其中图拍了的尺寸已经进行的标准化的处理,都是黑白的图像,大小为28X28
执行代码,下载数据,观察数据类型:
import torchvision
mnist_dataset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=None)
print(mnist_dataset)
print(mnist_dataset[0])
下载的数据如下:
代码输出结果如下:
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!
Dataset MNIST
Number of datapoints: 60000
Root location: ./data
Split: Train
(<PIL.Image.Image image mode=L size=28x28 at 0x265F4B2AD60>, 5)
可以其中数据集返回了两条数据,可以猜测为图片的数据和目标值
返回值的第0个为Image类型,可以调用show() 方法打开,发现为手写数字5
img = mnist_dataset[0][0]
img.show() #打开图片
图片如下:
由上可知:返回值为(图片,目标值)
,这个结果也可以通过观察源码得到