学习内容
深度学习数据加载流程:
# 创建自定义数据集, 建立索引到数据样本的映射
myData = MyDataset(**args)
# 通过DataLoader()加载myData,以指定方式从数据集中迭代生成 batch 样本集合
dataLoader = DataLoader(dataset, batch_size, shuffle, num_works)
# DataLoader为模型提供训练数据,根据sampler指定的策略生成训练样本
for i in range(epoch):
for idx, (sequence, ans) in enumerate(dataLoader):
pass
数据加载核心类:torch.utils.data.DataLoader
# DataLoader的构造函数参数配置
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
- dataset (Dataset) – 从中加载数据的数据集。
- batch_size (int, optional) – 每批要加载多少个样本(默认值:1)。
- shuffle (bool, optional) – 设置为 True时, 让数据在每个 epoch 重新洗牌(默认值:False)。
- sampler(Sampler 或 Iterable,可选)——定义从数据集中抽取样本的策略。可以是任何实现了 len 的 Iterable。如果指定,则不得指定 shuffle。
- batch_sampler(Sampler 或 Iterable,可选)- 类似于 sampler,但一次返回一批索引。与 batch_size、shuffle、sampler 和 drop_last 互斥。
- num_workers (int, optional) – 用于数据加载的子进程数。0 表示数据将在主进程中加载。(默认值:0)。
数据集:torch.utils.data.Dataset
表示数据集的抽象类,任何自定义的数据集都要继承这个类,并重写相关方法。
Pytorch支持两种不同类型的数据集:
- 映射类型数据集
所有映射类型数据集子类都应该重写__getitem__(self, index) ,支持获取给定键的数据样本。重写__len__(self) ,返回数据集的大小;表示索引到数据样本的映射。
class MyDataset(Dataset):#需要继承Dataset
def __init__(self):
# 初始化文件路径,文件名称
pass
def __getitem__(self, index):
# 读取数据 读取的是一个样本,而不是全部数据
# 数据预处理
# 返回data pair
pass
def __len__(self):
# 返回数据集大小
return len(self.data)
-
迭代类型数据集
Mark
采样器sampler:torch.utils.data.sampler.Sampler
PyTorch提供的Sampler
- SequentialSampler
- RandomSampler
- SubsetRandomSampler
- WeightedRandomSampler
详见: https://www.csdn.net/
也可以自己定义采样器:自定义时要继承 torch.utils.data.sampler.Sampler 抽象类。
我们在训练时常用的是对批量数据训练,而BatchSampler
的作用就是将前面的Sampler
采样得到的索引值合并成一个batch并返回。