Pytorch中有工具函数torch.utils.Data.DataLoader,通过这个函数我们在准备加载数据集使用mini-batch的时候可以使用多线程并行处理,这样可以加快我们准备数据集的速度。Datasets就是构建这个工具函数的实例参数之一。
from torch.utils.data import Dataset, DataLoader
如果想要继承Datasets,父类中的两个私有成员函数必须被重载。
def getitem(self, index):
def len(self):
Datasets的源代码解说:
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
Datasets的整体思路:
class FirstDataset(data.Dataset): # 需要继承data.Dataset
def __init__(self):
# TODO
# 1. 初始化文件路径或文件名列表。
#也就是在这个模块里,我们所做的工作就是初始化该类的一些基本参数。
pass
def __getitem__(self, index):
# TODO
# 1。从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open)。
# 2。预处理数据(例如torchvision.Transform)。
# 3。返回数据对(例如图像和标签)。
# 这里需要注意的是,第一步:read one data,是一个data
pass
def __len__(self):
# 更改为数据集的总大小。
DataLoader类实现以下功能:
batch_size:可以分批次读取
shuffle=True可以对数据进行随机读取,可以对数据进行洗牌操作(shuffling),打乱数据集内数据分布的顺序
num_workers=2可以并行加载数据(利用多核处理器加快载入数据的效率
batch :可以分批次读取:batch-size
在识别手写数字的例子中,创建一个读取小批量数据样本的DataLoader实例
train_set = datasets.FashionMNIST('D:/', train=True, transform=transform, download=False)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)