class VOC12ClsDataset(VOC12ImageDataset):
    """VOC12 image dataset whose items also carry a classification label.

    Each item is the ``(name, img)`` pair produced by ``VOC12ImageDataset``,
    extended with the label stored for the same index, converted to a torch
    tensor via ``torch.from_numpy``.
    """

    def __init__(self, img_name_list_path, voc12_root, transform=None):
        super().__init__(img_name_list_path, voc12_root, transform)
        # One label entry per image name; indexed in lockstep with the images.
        # NOTE(review): presumably each entry is a numpy array (e.g. a
        # per-class vector) since it is passed to torch.from_numpy — confirm.
        # An XML-based alternative exists:
        # load_image_label_list_from_xml(self.img_name_list, self.voc12_root)
        self.label_list = load_image_label_list_from_npy(self.img_name_list)

    def __getitem__(self, idx):
        """Return the ``(name, img, label)`` triple for sample *idx*.

        ``name`` and ``img`` come from the parent dataset; ``label`` is the
        stored label for *idx* wrapped as a tensor. ``torch.from_numpy``
        shares memory with the stored array rather than copying it.
        """
        sample_name, sample_img = super().__getitem__(idx)
        label_array = self.label_list[idx]
        return sample_name, sample_img, torch.from_numpy(label_array)
`return name, img, label` 返回的是一个元组（tuple）：第一个位置是 name，第二个位置是 img，第三个位置是 label。经 DataLoader 默认的 collate 函数批处理后，这个三元组会变成一个 list，元素顺序不变。