在学习这部分内容时,可以先明白几个术语的意思,DataSize、BatchSize、Iteration、Epoch分别是什么意思,他们之间的关系如何。
如图可知,所有数据训练一次就是一个Epoch,一次输入模型的数据大小就是一个Batch,处理一次batchsize就是Iteration。
Dataset:
首先继承dataset,
from torch.utils.data import DataLoader, Dataset
,主要包括三部分
class CustomDataset(data.Dataset):#需要继承data.Dataset
def __init__(self):
# 此处编辑文件的路径,获取到一个包含各个文件路径的列表,
#注意区分训练集与测试集的路径,一般设置个bool值True是取训练集
pass
def __getitem__(self, index):
# 此处打开单个文件,例如文本文件或者图片文件,一般返回文件和标签
pass
def __len__(self):
#返回一个长度,可以是所有文件的个数
return len(total_file)
DataLoad:
使用dataload一般可以编写一个get_dataloader(train = True