torch.utils.data.Dataset
Dataset is the abstract base class for datasets. Subclasses must implement __getitem__ and __len__:
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
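To make the contract above concrete, here is a minimal sketch of a subclass. The class name SquaresDataset and its samples are hypothetical and chosen only for illustration. It assumes PyTorch is installed and that the installed version still defines __add__ as shown in the excerpt above.

```python
from torch.utils.data import ConcatDataset, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, size):
        self.size = size

    def __getitem__(self, index):
        # Reject out-of-range indices so that plain iteration terminates.
        if not 0 <= index < self.size:
            raise IndexError(index)
        return index, index * index

    def __len__(self):
        return self.size


small = SquaresDataset(3)
large = SquaresDataset(5)

# The inherited __add__ wraps both operands in a ConcatDataset.
combined = small + large

print(len(small))               # 3
print(small[2])                 # (2, 4)
print(type(combined).__name__)  # ConcatDataset
print(len(combined))            # 8
print(combined[3])              # (0, 0), the first sample of `large`
```

Because the base class only raises NotImplementedError, a subclass that omits either method fails as soon as that method is called, not when the class is defined.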
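One of its subclasses is TensorDataset, which wraps tensors. Each sample is retrieved by indexing the wrapped tensors along the first dimension: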
class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Ar