PyTorch：实现可同时加载图像及其对应 mask 的 ImageFolder
目的
在分类任务中，如果训练集的每个样本除了原始图像外还带有对应的 mask（例如需要同时加载原始图像及其语义分割结果），PyTorch 自带的数据加载类无法直接满足这一需求。在不想对代码做大幅改动的前提下，可以采用下面的方式实现该功能。
代码实现
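我们以 torchvision 自带的 `torchvision/datasets/folder.py`（`ImageFolder` 所在的源文件）为基础，通过修改其中部分代码来实现图像对的加载。数据需按如下结构组织：`root/image/<类别>/xxx.jpg` 存放原始图像，`root/mask/<类别>/xxx.png` 存放对应的 mask。两个目录中的文件按列出顺序逐个配对，因此每个类别下图像与 mask 的数量和顺序必须一致（建议同名）。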
- 首先，将 `torchvision/datasets/folder.py` 的全部代码复制到一个新的 .py 文件中；
- 修改其中的 DatasetFolder 类，代码如下：
class DatasetFolder(VisionDataset):
    """A generic loader for (image, mask) pairs arranged in this way: ::

        root/image/class_x/xxx.ext
        root/image/class_x/[...]/xxz.ext
        root/image/class_y/123.ext
        root/mask/class_x/xxx.ext
        root/mask/class_x/[...]/xxz.ext
        root/mask/class_y/123.ext

    Classes are discovered from the sub-folders of ``root/image``; the same
    class names are then used to scan ``root/mask``. Images and masks are
    paired purely by position in the two lists returned by ``make_dataset``,
    so both trees must hold the same number of files per class, listed in
    the same order (identical file stems are the simplest way to get that).

    Args:
        root (string): Root directory containing the ``image`` and ``mask``
            sub-folders.
        loader (callable): A function to load an image given its path.
        mask_loader (callable): A function to load a mask given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A joint transform called as
            ``transform(image, mask)``; its return value becomes the sample.
            When ``None`` the sample is the ``(image, mask)`` tuple.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and check if the file is a valid file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        image_samples (list): List of (image path, class_index) tuples.
        mask_samples (list): List of (mask path, class_index) tuples, aligned
            index-by-index with ``image_samples``.
        samples (list): Alias of ``image_samples`` (torchvision compatibility).
        targets (list): The class_index value for each image in the dataset.

    Raises:
        RuntimeError: if no image files are found, or if the image and mask
            lists cannot be paired index-by-index.
    """

    def __init__(
            self,
            root: str,
            loader: Callable[[str], Any],
            mask_loader: Callable[[str], Any],
            extensions: Optional[Tuple[str, ...]] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform,
                         target_transform=target_transform)
        self.image_root = os.path.join(self.root, 'image')
        self.mask_root = os.path.join(self.root, 'mask')
        classes, class_to_idx = self._find_classes(self.image_root)
        image_samples = self.make_dataset(self.image_root, class_to_idx, extensions, is_valid_file)
        mask_samples = self.make_dataset(self.mask_root, class_to_idx, extensions, is_valid_file)
        if len(image_samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(self.image_root)
            if extensions is not None:
                msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)
        # Fail fast: __getitem__ pairs by index, so any mismatch here would
        # otherwise silently feed the wrong mask to an image.
        self._check_pairs(image_samples, mask_samples, self.mask_root)

        self.loader = loader
        self.mask_loader = mask_loader
        self.extensions = extensions
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.image_samples = image_samples
        self.mask_samples = mask_samples
        # Kept for code written against torchvision's DatasetFolder.
        self.samples = image_samples
        self.targets = [s[1] for s in image_samples]

    @staticmethod
    def _check_pairs(
            image_samples: List[Tuple[str, int]],
            mask_samples: List[Tuple[str, int]],
            mask_root: str,
    ) -> None:
        """Raise RuntimeError unless both lists can be paired index-by-index."""
        if len(mask_samples) != len(image_samples):
            raise RuntimeError(
                "Found {} images but {} masks (mask folder: {}); every image "
                "needs exactly one mask".format(len(image_samples), len(mask_samples), mask_root))
        for (image_path, image_target), (mask_path, mask_target) in zip(image_samples, mask_samples):
            if image_target != mask_target:
                raise RuntimeError(
                    "Image {} (class {}) is paired with mask {} (class {})".format(
                        image_path, image_target, mask_path, mask_target))

    @staticmethod
    def make_dataset(
            directory: str,
            class_to_idx: Dict[str, int],
            extensions: Optional[Tuple[str, ...]] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """List ``(path, class_index)`` pairs under ``directory``.

        Delegates to the module-level ``make_dataset`` function defined
        elsewhere in this file (presumably the copy from torchvision).
        """
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are the immediate
            sub-directory names of ``dir`` sorted alphabetically, and
            class_to_idx maps each name to its position in that list.
        """
        classes = sorted(entry.name for entry in os.scandir(dir) if entry.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where sample is ``transform(image, mask)``
            (or ``(image, mask)`` when no transform is set) and target is the
            class_index of the target class.
        """
        image_path, target = self.image_samples[index]
        # The mask's class index equals the image's (checked in __init__).
        mask_path, _ = self.mask_samples[index]
        image_sample = self.loader(image_path)
        mask_sample = self.mask_loader(mask_path)
        if self.transform is not None:
            sample = self.transform(image_sample, mask_sample)
        else:
            sample = (image_sample, mask_sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self) -> int:
        return len(self.image_samples)
- 修改其中的 ImageFolder 类，代码如下：
class ImageFolder(DatasetFolder):
    """A generic loader for (image, mask) pairs arranged in this way: ::

        root/image/dog/xxx.png
        root/image/dog/[...]/xxz.png
        root/image/cat/123.png
        root/mask/dog/xxx.png
        root/mask/dog/[...]/xxz.png
        root/mask/cat/123.png

    Args:
        root (string): Root directory containing the ``image`` and ``mask``
            sub-folders.
        transform (callable, optional): A joint transform called by
            ``DatasetFolder`` as ``transform(image, mask)``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        mask_loader (callable, optional): A function to load a mask given its
            path. Defaults to ``png_loader``.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files).
            When given, the ``IMG_EXTENSIONS`` filter is not applied.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples; the same list
            object as ``image_samples``.
    """

    # Default arguments are evaluated when this class body runs, so
    # png_loader (and default_loader) must be defined earlier in the module.
    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            loader: Callable[[str], Any] = default_loader,
            mask_loader: Callable[[str], Any] = png_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(
            root,
            loader,
            mask_loader,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.image_samples
- 再额外添加一个 png_loader 函数（它被用作 ImageFolder 中 mask_loader 的默认值，因此必须定义在 ImageFolder 类之前），代码如下：
def png_loader(path: str) -> Image.Image:
    """Load the mask stored at *path* as a palette-mode ("P") PIL image.

    The file is opened through an explicit handle that is closed on exit,
    which avoids Pillow's ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835). The ``convert``
    call runs while the handle is still open, so the pixel data is fully
    read before it closes.

    NOTE(review): masks already saved in "P" mode are returned unchanged in
    content. Per Pillow's docs, converting an RGB mask to "P" uses the WEB
    palette with dithering by default, which may alter label values; confirm
    that the masks are stored as palette/label PNGs.
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('P')
- 最后，还需修改 transforms，使其能够对图像和 mask 做同步变换（DatasetFolder 会以 `transform(image, mask)` 的形式调用它）。详细代码实现参考：
transforms同步处理图像和mask代码实现
注：上述代码在 PyTorch 1.8.0 版本下调试通过。