转载:https://www.freesion.com/article/3728236956/
PyTorch提供了两个数据原语:torch.utils.data.DataLoader:
在 周围包装一个可迭代对象Dataset
,以便轻松访问样本。torch.utils.data.Dataset:
存储样本及其相应的标签 ,被封装进DataLoader里,实现该方法封装自己的数据和标签。
Dataset
需要三个方法:
__init__ ()函数在实例化 Dataset 对象时运行一次
_getitem_()
函数从给定索引处的数据集中加载并返回一个样本index
_len_()函数
返回我们数据集中的样本数
例子:
import torch
import numpy as np
# 定义GetLoader类,继承Dataset方法,并重写__getitem__()和__len__()方法
class GetLoader(torch.utils.data.Dataset):
# 初始化函数,得到数据
def __init__(self, data_root, data_label):
self.data = data_root
self.label = data_label
# index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
def __getitem__(self, index):
data = self.data[index]
labels = self.label[index]
return data, labels
# 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
def __len__(self):
return len(self.data)
# 随机生成数据,大小为10 * 20列
source_data = np.random.rand(10, 20)
# 随机生成标签,大小为10 * 1列
source_label = np.random.randint(0,2,(10, 1))
# 通过GetLoader将数据进行加载,返回Dataset对象,包含data和labels
torch_data = GetLoader(source_data, source_label)
使用DataLoader 进行训练
提供对Dataset
的操作,该Dataset
检索我们的数据的功能,并在同一时间标签一个样本。在训练模型时,我们通常希望以“小批量”形式传递样本,在每个时期重新洗牌数据以减少模型过度拟合,并使用 Pythonmultiprocessing
来加速数据检索操作如下:
torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)
各参数含义:
dataset: 加载torch.utils.data.Dataset对象数据
batch_size: 每个batch的大小
shuffle:是否对数据进行打乱
drop_last:是否对无法整除的最后一个datasize进行丢弃
num_workers:表示加载的时候子进程数
读取数据
from torch.utils.data import DataLoader
# 读取数据
datas = DataLoader(torch_data, batch_size=6, shuffle=True, drop_last=False, num_workers=2)
查看数据
我们可以通过迭代器(enumerate)
进行输出数据,测试如下:
for i, data in enumerate(datas):
# i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
print("第 {} 个Batch \n{}".format(i, data))
版权声明:本文为l8947943原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/l8947943/article/details/103733473