Splitting the CIFAR dataset in torchvision
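As is well known, indexing an object with square brackets, e.g. dataset[i], is implemented by calling the __getitem__() method defined in its class.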
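Slicing a dataset, e.g. dataset[0:100], normally goes through the same method: in Python 3 the 0:100 inside the brackets is turned into a slice object, slice(0, 100), and that object is passed to __getitem__().
So for a dataset, trainset[2:10] is equivalent to trainset.__getitem__(slice(2, 10)).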
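A minimal check of this, using a hypothetical toy class Probe (not part of any library) whose __getitem__() simply returns whatever index it receives:

```python
class Probe:
    """Hypothetical toy class: __getitem__ just echoes the index it is given."""

    def __getitem__(self, index):
        return index


p = Probe()
print(p[3])                      # an int index arrives unchanged: 3
print(p[2:10])                   # the slice syntax arrives as: slice(2, 10, None)
print(p[2:10] == slice(2, 10))   # True
```

Any class that wants to support slicing therefore has to check whether it received an int or a slice and handle each case.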
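The CIFAR datasets (torchvision.datasets.CIFAR10, and CIFAR100, which subclasses it), however, cannot be sliced this way. Here is the source of their __getitem__():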
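# Excerpt from torchvision.datasets.CIFAR10 (torchvision/datasets/cifar.py)
def __getitem__(self, index: int) -> Tuple[Any, Any]:
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    # self.data is a uint8 array of shape (N, 32, 32, 3); self.targets is a list of ints
    img, target = self.data[index], self.targets[index]

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    # (with a slice index, img is a 4-D array of several images and this call fails)
    img = Image.fromarray(img)

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target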
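To see why slicing breaks, the sketch below reproduces only the two steps of __getitem__() that matter. As an assumption for illustration, a zero-filled NumPy array of shape (10, 32, 32, 3) stands in for CIFAR10's self.data, which stores the images as uint8 in that (N, H, W, C) layout.

```python
import numpy as np
from PIL import Image

# Stand-in for CIFAR10's self.data (assumption): 10 blank 32x32 RGB images, uint8
data = np.zeros((10, 32, 32, 3), dtype=np.uint8)

single = Image.fromarray(data[2])   # int index -> (32, 32, 3) array -> one RGB image
print(single.size)                  # (32, 32)

batch = data[slice(2, 10)]          # slice index -> (8, 32, 32, 3) array
failed = False
try:
    Image.fromarray(batch)          # PIL cannot build a single image from a 4-D array
except (TypeError, ValueError) as exc:
    failed = True
    print("fromarray failed:", type(exc).__name__)
```

The exact exception type can differ between Pillow versions, which is why the sketch catches both TypeError and ValueError.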
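Because __getitem__() is written for a single int index (with a slice, self.data[index] is a stack of images that Image.fromarray() cannot convert into one image), we need another way to split the dataset. PyTorch's random_split does this: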
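import torchvision
from torch.utils.data import random_split
dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True)  # root is an example path
train_size = int(0.8 * len(dataset))    # 80% of the samples for training
valid_size = len(dataset) - train_size  # the remaining 20% for validation
lengths = [train_size, valid_size]
train_dataset, valid_dataset = random_split(dataset, lengths)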
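The split can be tried without downloading CIFAR by substituting a small TensorDataset of 100 blank images (an assumption made only for this sketch; pass the CIFAR10 object in real use). The sketch also passes a seeded torch.Generator so the split is reproducible, and shows torch.utils.data.Subset, the class random_split returns, used with range(2, 10) to get the contiguous slice that trainset[2:10] was meant to produce.

```python
import torch
from torch.utils.data import Subset, TensorDataset, random_split

# Stand-in for CIFAR10 (assumption): 100 blank 3x32x32 images with labels 0-9
images = torch.zeros(100, 3, 32, 32)
labels = torch.arange(100) % 10
dataset = TensorDataset(images, labels)

# 80/20 split; the seeded generator makes the split identical on every run
train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size
train_dataset, valid_dataset = random_split(
    dataset, [train_size, valid_size], generator=torch.Generator().manual_seed(42)
)
print(len(train_dataset), len(valid_dataset))  # 80 20

# Both parts are Subset objects whose .indices do not overlap
overlap = set(train_dataset.indices) & set(valid_dataset.indices)
print(len(overlap))  # 0

# A contiguous "slice" like trainset[2:10]: wrap the index range in a Subset
first_slice = Subset(dataset, range(2, 10))
img, label = first_slice[0]  # same sample as dataset[2]
print(len(first_slice), label.item())  # 8 2
```

Without the generator argument, random_split draws from PyTorch's global default generator, so the split changes between runs unless torch.manual_seed() is called first.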