torchvision.datasets.ImageFolder

目录

继承关系

初始化方法:

        一:find_classes

二:make_dataset

三: 写一个验证函数

四:loader

五:

六: __getitem__:

总结:


继承关系

class ImageFolder(DatasetFolder):

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(
            root,
            loader,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples

啥事没干

class DatasetFolder(VisionDataset):

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """Generates a list of samples of a form (path_to_sample, class).



        return find_classes(directory)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)

初始化方法:

        一:find_classes

classes, class_to_idx = self.find_classes(self.root)

具体实现主要是:

def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

打断点进去看一下:

 总的来说,就是根据路径:得到,文件名,数字索引。当然它将文件名表示为类别。

二:make_dataset

samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

参数说明: root: 文件地址

  class_to_idx:类别索引

extensions:图片后缀 ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

is_valid_file: 是一个可调用的函数: Optional[Callable[[str], bool]] 

返回:图片路径,和类别(索引)

                if is_valid_file(path):
                    item = path, class_index
                    instances.append(item)

如果验证通过则会加到返回中,反之不会。

三: 写一个验证函数

class Check:
    def __init__(self, key):
        print('看看值')
        print(key)

    def __call__(self, *args, **kwargs):
        return True

使用:直接传进去

class TestDataset(torchvision.datasets.ImageFolder):
    # 根路径,
    def __init__(self, root, imgsz, cache, augment):
        super().__init__(root=root, is_valid_file=Check)

结果:

 传的是一个图片地址,可以拿到图片做一些验证工作。

四:loader

self.loader = loader

也是一个回调函数

loader: Callable[[str], Any],

默认提供的是:

def default_loader(path: str) -> Any:
    from torchvision import get_image_backend

    if get_image_backend() == "accimage":
        return accimage_loader(path)
    else:
        return pil_loader(path)
def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")

一个读取图片的方法而已。

五:

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

我们其实只关心samples 图片和targers 标签

六: __getitem__:

        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

总结:

如果要重写,我觉得主要就是getitem方法。

首先:类给主要给我们提供了,文件读取的方法。我们可以直接拿到文件路径集合。

有了文件路径,就没必要用它的文件加载方法。yolo中有更高效的方法,如果是目标检测,我们可以重新指定标签,比如yolo中规定标签和图片名一样,便于找到。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值