pyTorch根据filelist加载自己的数据集合,无论图片是否在一个文件夹还是一个类的图片在一个文件夹。
第一步:继承实现Dataset类别
def default_loader(path):
return Image.open(path).convert('RGB')
class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0],int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
else:
img = Tensor.from_numpy(img)
return img,label
def __len__(self): return len(self.imgs)
第二步骤:就直接可以用自己定义的这个类,来构建自己的dataset了
transform = transforms.Compose([transforms.Scale((227,227)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
train_data = MyDataset(txt='train_filelist.txt',transform=transform)
其中比较有用的一个点是
transforms.Scale((227,227))
用来将不同大小的图片resize到统一尺寸。
还有一个点就是,彩色图片都要做的归一化
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])