len 方法,提供了dataset的大小; getitem 方法, 该方法支持从 0 到 len(self)的索引。
#filenames是训练数据文件名称列表,labels是标签列表
class MyDataset(Dataset):
def __init__(self, filenames, labels, transform):
self.filenames = filenames
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.filenames)
def __getitem__(self, idx):
image = Image.open(self.filenames[idx]).convert('RGB')
image = self.transform(image)
return image, self.labels[idx]
这样,就得到了Dataset类型的数据,接下来可以直接用Dataloader加载了。