pytorch为我们提供了Dataset类来提供所用数据集的创建任务。
数据集有两种情况:
1.pytorch中写好的数据集,如CIFAR10,我们在使用该数据集时只需要以下代码:data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)
datasets.CIFAR10就是Datasets的一个子类,data是这个类的一个实例。
2.利用Dataset自定义数据集:
模板为
class MovingMNISTdataset(Dataset):#需要继承Dataset类
##dataset class for moving MNIST data
##Initialize
def __init__(self, path):
self.path = path
self.data = MNISTdataLoader(path)
def __len__(self):
return len(self.data[:, 0, 0, 0])
def __getitem__(self, indx):
##getitem method
self.trainsample_ = self.data[indx, ...]
self.sample_ = self.trainsample_/255.0
self.sample = torch.from_numpy(np.expand_dims(self.sample_, axis = 1)).float()
return self.sample
其中 getitem(self, index), len(self) 两个内建方法,用来表示从索引到样本的映射(Map).