图片文件在同一的文件夹下
思路是继承 torch.utils.data.Dataset,并重点重写其 __getitem__方法,示例代码如下:
class ImageFolder(Dataset):
def __init__(self, folder_path):
self.files = sorted(glob.glob('%s/*.*' % folder_path))
def __getitem__(self, index):
path = self.files[index % len(self.files)]
img = np.array(Image.open(path))
h, w, c = img.shape
pad = ((40, 40), (4, 4), (0, 0))
# img = np.pad(img, pad, 'constant', constant_values=0) / 255
img = np.pad(img, pad, mode='edge') / 255.0
img = torch.from_numpy(img).float()
patches = np.reshape(img, (3, 10, 128, 11, 128))
patches = np.transpose(patches, (0, 1, 3, 2, 4))
return img, patches, path
def __len__(self):
return len(self.files)
图片文件在不同的文件夹下
比如我们有数据如下:
─── data
├── train
│ ├── 0.jpg
│ └── 1.jpg
├── test
│ ├── 0.jpg
│ └── 1.jpg
└── val
├── 1.jpg
└── 2.jpg
此时我们只需要将以上代码稍作修改即可,修改的代码如下:
self.files = sorted(glob.glob('%s/**/*.*' % folder_path, recursive=True))
其他代码不变。
reference