Dataset 只负责数据的抽象,一次只返回一个数据或者样本
Pytorch中数据集被抽象为一个抽象类torch.utils.data.Dataset,所有的数据集都应该继承这个类,并override以下两项:
__len__:代表样本数量。len(obj)等价于obj.__len__()。
__getitem__:返回一条数据或一个样本。obj[index]等价于obj.__getitem__。建议将节奏的图片等高负载的操作放到这里,因为多进程时会并行调用这个函数,这样做可以加速。
dataset中应尽量只包含只读对象,避免修改任何可变对象。因为如果使用多进程,可变对象要加锁,但后面讲到的dataloader的设计使其难以加锁。如下面例子中的self.num可能在多进程下出问题:
class BadDataset(Dataset):
def __init__(self):
self.datas = range(100)
self.num = 0 # read data times
def __getitem__(self, index):
self.num += 1
return self.datas[index]
Dataloader 前面提到过,在训练神经网络时,最好是对一个batch的数据进行操作,同时还需要对数据进行shuffle和并行加速等。对此,PyTorch提供了DataLoader帮助我们实现这些功能。
官方documentation
Dataset负责表示数据集&#