https://zhuanlan.zhihu.com/p/200876072
class DatasetSplit(Dataset):
    """Dataset view exposing only the samples of ``dataset`` listed in ``idxs``.

    Item ``k`` of this split is item ``idxs[k]`` of the wrapped dataset. Both
    elements of the ``(image, label)`` pair are returned as newly allocated
    ``torch.Tensor`` objects, so callers never alias the wrapped dataset's
    storage.

    Args:
        dataset: Indexable object whose items unpack into ``(image, label)``.
        idxs: Iterable of indices into ``dataset``. Each is coerced with
            ``int()``, and the iteration order of ``idxs`` fixes the order of
            the split.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Materialize into a list of plain ints: fixes the order once (idxs
        # may be e.g. a set) and normalizes index types such as numpy scalars.
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.idxs)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair at position ``item`` as tensors."""
        image, label = self.dataset[self.idxs[item]]
        return self._to_tensor(image), self._to_tensor(label)

    @staticmethod
    def _to_tensor(value):
        """Return a new tensor holding ``value``'s data, detached from any graph.

        ``torch.tensor(t)`` on an existing tensor is documented as equivalent
        to ``t.clone().detach()`` but emits a UserWarning on every call. That
        form is used directly for tensor inputs to keep identical copy
        semantics without the per-sample warning.
        """
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        return torch.tensor(value)