数据格式
XX.jpg label
1.jpg 1
Code
# Preprocessing pipelines keyed by dataset split ('train' / 'valid').
# Every image is resized to 64x64, converted to a tensor, and normalized with
# the standard ImageNet RGB mean/std values.
# Training augmentation (e.g. transforms.RandomHorizontalFlip()) is currently
# disabled, so both splits use the same steps. Each split still gets its own
# Compose instance, so changing one pipeline later does not affect the other.
data_transforms = {
    split_name: transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    for split_name in ('train', 'valid')
}
class OccGenDataset(Dataset):
    """Image-classification dataset driven by a plain-text annotation file.

    Each non-blank line of ``anno_txt`` has the form ``<image file> <label>``
    (whitespace-separated, e.g. ``1.jpg 1``). Any extra columns are ignored.
    Image file names are joined onto ``img_path``.

    Args:
        split: Key into the module-level ``data_transforms`` dict
            (``'train'`` or ``'valid'``) that selects the preprocessing
            pipeline. It is only used when ``transform`` is None.
        img_path: Directory that the file names in ``anno_txt`` are joined to.
        anno_txt: Path to the annotation file.
        transform: Optional callable applied to each PIL image instead of
            ``data_transforms[split]``.

    Raises:
        ValueError: If a non-blank annotation line has fewer than two fields.
        KeyError: If ``transform`` is None and ``split`` is not a key of
            ``data_transforms``.
    """

    def __init__(self, split, img_path, anno_txt, transform=None):
        self.img_list = []
        self.label_list = []  # labels kept as strings; converted in __getitem__
        with open(anno_txt, 'r') as f:
            for lineno, line in enumerate(f, start=1):
                fields = line.split()  # parse each line exactly once
                if not fields:
                    # Skip blank lines (e.g. a trailing empty line at EOF)
                    # instead of failing with an opaque IndexError.
                    continue
                if len(fields) < 2:
                    raise ValueError(
                        f"{anno_txt}:{lineno}: expected '<image> <label>', "
                        f"got {line.rstrip()!r}"
                    )
                self.img_list.append(os.path.join(img_path, fields[0]))
                self.label_list.append(fields[1])
        self.transformer = (
            data_transforms[split] if transform is None else transform
        )

    def __getitem__(self, i):
        """Return ``(transformed_image, int_label)`` for sample ``i``."""
        # The context manager releases the file handle that PIL otherwise
        # keeps open for lazy decoding. convert('RGB') loads the pixels and
        # guarantees 3 channels. Grayscale, palette, RGBA or CMYK images
        # would otherwise not match the 3-value mean/std normalization.
        with Image.open(self.img_list[i]) as raw:
            img = raw.convert('RGB')
        return self.transformer(img), int(self.label_list[i])

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.img_list)