The Three Base Classes of torchvision.datasets

Summary

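torchvision.datasets provides three base classes for building custom datasets: DatasetFolder, ImageFolder and VisionDataset. All three are subclasses of torch.utils.data.Dataset, and they also inherit from one another. VisionDataset is defined in datasets/vision.py, while DatasetFolder and ImageFolder are defined in datasets/folder.py. VisionDataset has no working __getitem__ or __len__ (both raise NotImplementedError). DatasetFolder inherits from VisionDataset and implements __getitem__ and __len__, and ImageFolder in turn inherits from DatasetFolder.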

Related Links

vision/torchvision/datasets at main · pytorch/vision (github.com)

Datasets — Torchvision 0.16 documentation (pytorch.org)

Examples

CIFAR10

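torchvision.datasets.CIFAR10 inherits from VisionDataset, overrides __getitem__ and __len__, and defines a _load_meta method that plays a role similar to find_classes: it reads the class names from the dataset's meta file and builds class_to_idx from them.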

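# Partial code; for the full source see https://pytorch.org/vision/0.16/_modules/torchvision/datasets/cifar.html#CIFAR10
import os
import pickle
from typing import Any, Tuple

from PIL import Image
from torchvision.datasets.utils import check_integrity
from torchvision.datasets.vision import VisionDataset


class CIFAR10(VisionDataset):
    # __init__ (download, loading self.data / self.targets, etc.) omitted
    def _load_meta(self) -> None:
        path = os.path.join(self.root, self.base_folder, self.meta["filename"])
        if not check_integrity(path, self.meta["md5"]):
            raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it")
        with open(path, "rb") as infile:
            data = pickle.load(infile, encoding="latin1")
            self.classes = data[self.meta["key"]]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        img, target = self.data[index], self.targets[index]

        # self.data holds numpy arrays; convert to a PIL Image so that
        # CIFAR10 returns the same type as the other datasets
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)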
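A detail worth noticing in _load_meta: class indices follow the order stored in the meta file, whereas the folder-based find_classes sorts names alphabetically. The stdlib-only sketch below writes a stand-in meta file and reads it the same way; load_meta_sketch and the label list are hypothetical, and the MD5 check_integrity step is skipped.

```python
import os
import pickle
import tempfile
from typing import Dict, List, Tuple

# Hypothetical, deliberately unsorted label list (not the real CIFAR-10 labels).
FAKE_LABEL_NAMES = ["ship", "airplane", "cat"]


def write_fake_meta(folder: str) -> str:
    """Write a stand-in file shaped like CIFAR-10's batches.meta."""
    path = os.path.join(folder, "batches.meta")
    with open(path, "wb") as f:
        pickle.dump({"label_names": FAKE_LABEL_NAMES}, f)
    return path


def load_meta_sketch(path: str, key: str = "label_names") -> Tuple[List[str], Dict[str, int]]:
    """Same idea as CIFAR10._load_meta, minus the integrity check."""
    with open(path, "rb") as infile:
        data = pickle.load(infile, encoding="latin1")
    classes = data[key]
    # Indices follow the file order; nothing is sorted here.
    class_to_idx = {_class: i for i, _class in enumerate(classes)}
    return classes, class_to_idx


with tempfile.TemporaryDirectory() as tmp:
    classes, class_to_idx = load_meta_sketch(write_fake_meta(tmp))

print(classes)       # ['ship', 'airplane', 'cat']
print(class_to_idx)  # {'ship': 0, 'airplane': 1, 'cat': 2}
```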

VOCSegmentation

torchvision.datasets.VOCSegmentation inherits from _VOCBase, which in turn inherits from VisionDataset.

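from typing import Any, List, Tuple

from PIL import Image
from torchvision.datasets.voc import _VOCBase


class VOCSegmentation(_VOCBase):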

    _SPLITS_DIR = "Segmentation"
    _TARGET_DIR = "SegmentationClass"
    _TARGET_FILE_EXT = ".png"

    @property
    def masks(self) -> List[str]:
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert("RGB")
        target = Image.open(self.masks[index])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target
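VOCSegmentation passes the image and the mask through self.transforms together because geometric augmentations must hit both identically, otherwise labels drift off their pixels. The sketch below illustrates this with toy nested lists standing in for PIL images; joint_random_hflip is a hypothetical helper, and a callable with this (img, target) signature is the kind of thing one would pass as `transforms=`.

```python
import random
from typing import List, Set, Tuple

Grid = List[List[int]]  # toy stand-in for a PIL image or a segmentation mask


def hflip(grid: Grid) -> Grid:
    """Horizontal flip: mirror every row."""
    return [row[::-1] for row in grid]


def joint_random_hflip(img: Grid, target: Grid, p: float = 0.5) -> Tuple[Grid, Grid]:
    """Hypothetical joint transform: ONE random draw decides for both inputs."""
    if random.random() < p:
        return hflip(img), hflip(target)
    return img, target


def pixel_label_pairs(img: Grid, target: Grid) -> Set[Tuple[int, int]]:
    """Which mask label sits under which pixel value."""
    return {(px, lbl) for row_i, row_t in zip(img, target) for px, lbl in zip(row_i, row_t)}


img = [[1, 2, 3], [4, 5, 6]]
mask = [[0, 0, 1], [0, 1, 1]]
for _ in range(10):
    out_img, out_mask = joint_random_hflip(img, mask)
    # Whether or not the flip fired, every pixel keeps its original label.
    assert pixel_label_pairs(out_img, out_mask) == pixel_label_pairs(img, mask)
```

Applying two independent random flips via transform and target_transform would break this guarantee, which is why segmentation datasets use the joint transforms argument.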

VisionDataset

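Inherits from torch.utils.data.Dataset; subclasses still have to override __getitem__ and __len__.

Parameters

  • root: root directory of the dataset; VisionDataset itself only uses it in __repr__

  • transforms: a function/transform that takes an image and its label and returns transformed versions of both

  • transform: a function/transform applied to the image, returning the transformed version

  • target_transform: a function/transform applied to the label (target), returning the transformed version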

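import os
from typing import Any, Callable, List, Optional

import torch.utils.data as data
from torchvision.utils import _log_api_usage_once

# StandardTransform is defined further down in the same file, datasets/vision.py


class VisionDataset(data.Dataset):
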
    _repr_indent = 4

    def __init__(
        self,
        root: str = None,  # type: ignore[assignment]
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        _log_api_usage_once(self)
        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = root

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError("Only transforms or transform/target_transform can be passed as argument")

        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        if has_separate_transform:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    def __getitem__(self, index: int) -> Any:

        raise NotImplementedError

    def __len__(self) -> int:
        raise NotImplementedError

    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = [f"Number of datapoints: {self.__len__()}"]
        if self.root is not None:
            body.append(f"Root location: {self.root}")
        body += self.extra_repr().splitlines()
        if hasattr(self, "transforms") and self.transforms is not None:
            body += [repr(self.transforms)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return "\n".join(lines)

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        lines = transform.__repr__().splitlines()
        return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]

    def extra_repr(self) -> str:
        return ""
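The __init__ above accepts either one joint transforms callable or the separate transform/target_transform pair, never both, and wraps the pair into a StandardTransform so subclasses can always call self.transforms(img, target). Below is a stdlib-only re-creation of that logic; StandardTransformSketch and resolve_transforms are hypothetical names, not torchvision APIs.

```python
from typing import Any, Callable, Optional, Tuple


class StandardTransformSketch:
    """Wrap a per-image and a per-target transform into one (img, target) callable."""

    def __init__(self, transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None) -> None:
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, img: Any, target: Any) -> Tuple[Any, Any]:
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


def resolve_transforms(transforms: Optional[Callable] = None,
                       transform: Optional[Callable] = None,
                       target_transform: Optional[Callable] = None) -> Optional[Callable]:
    """Same argument check as VisionDataset.__init__."""
    has_transforms = transforms is not None
    has_separate_transform = transform is not None or target_transform is not None
    if has_transforms and has_separate_transform:
        raise ValueError("Only transforms or transform/target_transform can be passed as argument")
    if has_separate_transform:
        transforms = StandardTransformSketch(transform, target_transform)
    return transforms


joint = resolve_transforms(transform=str.upper, target_transform=lambda t: t + 1)
print(joint("cat", 0))  # ('CAT', 1)
```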

torchvision.datasets.folder

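This file contains the DatasetFolder class together with the module-level functions find_classes() and make_dataset() that DatasetFolder calls by default.

find_classes()

The function that DatasetFolder.find_classes calls by default. It finds the class folders of a dataset stored in the following structure: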

directory/
├── class_x
│   ├── xxx.ext
│   ├── xxy.ext
│   └── ...
│       └── xxz.ext
└── class_y
    ├── 123.ext
    ├── nsdf3.ext
    └── ...
        └── asd932_.ext

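Parameters:

  • directory: root directory of the dataset

Returns:

  • classes: the sorted list of class (folder) names
  • class_to_idx: a dict mapping each class name to its index

import os
from typing import Dict, List, Tuple


def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: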

    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
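To see what this returns in practice, the snippet below builds a throwaway directory with two hypothetical class folders and a stray file. The function body is copied so the snippet runs without torchvision; the takeaway is that indices follow alphabetical order, not creation order, and plain files at the top level are ignored.

```python
import os
import tempfile
from typing import Dict, List, Tuple


def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    # Same logic as the function above, copied so this snippet is self-contained.
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    return classes, {cls_name: i for i, cls_name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    for name in ("dog", "cat"):  # hypothetical class folders, created out of order
        os.makedirs(os.path.join(root, name))
    open(os.path.join(root, "README.txt"), "w").close()  # plain files are skipped
    classes, class_to_idx = find_classes(root)

print(classes)       # ['cat', 'dog']
print(class_to_idx)  # {'cat': 0, 'dog': 1}
```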

make_dataset()

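Parameters

  • directory: root directory of the dataset
  • class_to_idx: mapping from class name to class index; if None, it is built with find_classes()
  • extensions: allowed file extension(s)
  • is_valid_file: function that checks whether a file is valid (exactly one of extensions and is_valid_file must be given)

Returns

  • instances: list of samples, each a tuple (path_to_sample, class_index)

import os
from typing import Callable, Dict, List, Optional, Tuple, Union, cast


def make_dataset(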
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
  
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = path, class_index
                    instances.append(item)

                    if target_class not in available_classes:
                        available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances
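The core of make_dataset is a walk over each class folder that keeps files passing the filter and labels them with the top-level folder's index. The simplified sketch below keeps only the extension branch and drops the empty-class check; make_dataset_sketch and the file names are hypothetical. Note that files in nested subfolders are included (os.walk) and that extension matching ignores the filename's case.

```python
import os
import tempfile
from typing import Dict, List, Tuple


def make_dataset_sketch(directory: str, class_to_idx: Dict[str, int],
                        extensions: Tuple[str, ...]) -> List[Tuple[str, int]]:
    """Simplified make_dataset: extension filter only, no empty-class check."""
    instances = []
    for target_class in sorted(class_to_idx):
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if path.lower().endswith(extensions):
                    instances.append((path, class_to_idx[target_class]))
    return instances


with tempfile.TemporaryDirectory() as root:
    layout = {"cat": ["b.png", "notes.txt"], "dog": ["a.JPG", os.path.join("sub", "c.png")]}
    for cls, files in layout.items():
        for rel in files:
            path = os.path.join(root, cls, rel)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            open(path, "w").close()
    samples = make_dataset_sketch(root, {"cat": 0, "dog": 1}, (".png", ".jpg"))
    rel_samples = [(os.path.relpath(p, root), idx) for p, idx in samples]

print(rel_samples)
```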

has_file_allowed_extension()

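Checks whether a file has one of the allowed extensions.

Parameters

  • filename: path of the file
  • extensions: allowed extension(s), as a single string or a tuple of strings

from typing import Tuple, Union


def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool: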
    return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))
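Only the filename is lower-cased, never the extensions, so the extensions you pass should themselves be lowercase. A list also works at runtime because it is converted to a tuple (str.endswith accepts only a str or a tuple). The function is copied below so the checks run standalone.

```python
from typing import Tuple, Union


def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
    # Copied from above: filename.lower() is compared against the extensions as given.
    return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))


print(has_file_allowed_extension("photo.JPG", (".jpg", ".png")))  # True
print(has_file_allowed_extension("photo.jpg", (".JPG",)))         # False
print(has_file_allowed_extension("archive.tar.gz", ".gz"))        # True
print(has_file_allowed_extension("a.png", [".png"]))              # True
```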

is_image_file()

Checks whether a file has one of the allowed image extensions.

Parameters:

  • filename: path of the file

IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")

def is_image_file(filename: str) -> bool:
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)

DatasetFolder()

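Located in torchvision.datasets.folder and inheriting from VisionDataset, DatasetFolder is a generic data loader; the expected directory structure can be customized by overriding the find_classes method.

__init__()

Parameters

  • root: root directory
  • loader: function that loads a sample from a given path
  • extensions: allowed file extensions
  • transform: transform applied to the sample
  • target_transform: transform applied to the target (label)
  • is_valid_file: function that takes a file path and checks whether the file is valid
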
class DatasetFolder(VisionDataset):  # excerpt
    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

make_dataset()

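Parameters

  • directory: root directory of the dataset, corresponding to self.root

  • class_to_idx: dict mapping class names to class indices

  • extensions: allowed file extensions

  • is_valid_file: function that checks whether a file is valid

Here class_to_idx must not be None. The static method delegates to the module-level make_dataset() function, and if class_to_idx were None, that function would fall back to the module-level find_classes() instead of DatasetFolder.find_classes(). Because subclasses may override find_classes, the two could produce different class mappings, so the method raises an error instead.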

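class DatasetFolder(VisionDataset):  # excerpt
    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        if class_to_idx is None:
            # otherwise the module-level make_dataset would use the module-level
            # find_classes, which may differ from an overridden find_classes method
            raise ValueError("The class_to_idx parameter cannot be None.")
        # calls the module-level make_dataset() function defined earlier in folder.py
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)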

find_classes()

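By default it simply returns the result of the module-level find_classes function in torchvision.datasets.folder; override it to support a different dataset structure.

Parameters

  • directory: root directory, corresponding to self.root

class DatasetFolder(VisionDataset):  # excerpt
    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        # delegates to the module-level find_classes() defined earlier in folder.py
        return find_classes(directory)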
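A common reason to override find_classes is that the default treats every subfolder as a class, including ones like .ipynb_checkpoints that Jupyter creates. The function below is a hypothetical variant with the same return shape that skips names starting with "." or "_". In a DatasetFolder subclass, the override would be `def find_classes(self, directory): return find_classes_skip_hidden(directory)`; since __init__ hands the resulting class_to_idx to make_dataset, the skipped folders are never scanned.

```python
import os
import tempfile
from typing import Dict, List, Tuple


def find_classes_skip_hidden(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Hypothetical find_classes variant that ignores hidden/underscore folders."""
    classes = sorted(
        entry.name for entry in os.scandir(directory)
        if entry.is_dir() and not entry.name.startswith((".", "_"))
    )
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    return classes, {name: i for i, name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    for name in ("dog", "cat", ".ipynb_checkpoints", "_unlabeled"):
        os.makedirs(os.path.join(root, name))
    classes, class_to_idx = find_classes_skip_hidden(root)

print(classes)       # ['cat', 'dog']
print(class_to_idx)  # {'cat': 0, 'dog': 1}
```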

__getitem__() and __len__()

class DatasetFolder(VisionDataset):  # excerpt
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)
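Because samples only stores (path, class_index) pairs, files are read lazily, one per __getitem__ call, and the loader does not have to produce images at all. The stdlib stand-in below (FolderSketch and text_loader are hypothetical names, not the real class) mirrors the __getitem__/__len__ logic with a text-file loader. With torchvision installed, the analogous real call would be along the lines of `DatasetFolder(root, loader=text_loader, extensions=(".txt",))`.

```python
import os
import tempfile
from typing import Any, Callable, List, Optional, Tuple


def text_loader(path: str) -> str:
    """Hypothetical loader: DatasetFolder only needs path -> sample."""
    with open(path, encoding="utf-8") as f:
        return f.read()


class FolderSketch:
    """Stdlib stand-in for DatasetFolder's __getitem__/__len__."""

    def __init__(
        self,
        samples: List[Tuple[str, int]],
        loader: Callable[[str], Any],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self.samples = samples
        self.targets = [s[1] for s in samples]
        self.loader = loader
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        path, target = self.samples[index]
        sample = self.loader(path)  # the file is only read here, on demand
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self) -> int:
        return len(self.samples)


with tempfile.TemporaryDirectory() as root:
    path = os.path.join(root, "hello.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write("hello")
    ds = FolderSketch([(path, 3)], text_loader, transform=str.upper)
    item = ds[0]

print(len(ds), item)  # 1 ('HELLO', 3)
```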

ImageFolder()

Inherits from DatasetFolder. A generic data loader for images, which by default are arranged as follows:

root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png

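Parameters:

  • root: root directory of the dataset
  • transform: transform applied to a PIL image
  • target_transform: transform applied to the target (label)
  • loader: function that loads an image from a given path; defaults to default_loader, defined in the same file (folder.py)
  • is_valid_file: function that checks whether a file is valid; if it is given, ImageFolder passes extensions=None to DatasetFolder, so this function alone decides which files are kept
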
class ImageFolder(DatasetFolder):

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(
            root,
            loader,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples
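Since IMG_EXTENSIONS is only used when is_valid_file is None, a custom filter has to check extensions itself. The sketch below is a hypothetical filter (is_valid_image is not a torchvision function) that keeps image files while rejecting hidden files such as the "._name.jpg" files macOS leaves on external drives; it would be passed as `ImageFolder(root, is_valid_file=is_valid_image)`.

```python
import os

# Same tuple as IMG_EXTENSIONS in folder.py (torchvision 0.16).
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def is_valid_image(path: str) -> bool:
    """Hypothetical filter: allowed image extension AND not a hidden file."""
    name = os.path.basename(path)
    return not name.startswith(".") and name.lower().endswith(IMG_EXTENSIONS)


print(is_valid_image("/data/cat/1.PNG"))    # True
print(is_valid_image("/data/cat/._1.png"))  # False
print(is_valid_image("/data/cat/notes.txt"))  # False
```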

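loader

folder.py provides three loaders: default_loader, pil_loader and accimage_loader. They rely on two image libraries, accimage and PIL: default_loader calls accimage_loader when torchvision.get_image_backend() returns "accimage" and pil_loader otherwise, and accimage_loader itself falls back to pil_loader on a decoding error. For a short introduction to PIL, see also the PIL Image module documentation.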

def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    import accimage

    try:
        return accimage.Image(path)
    except OSError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path: str) -> Any:
    from torchvision import get_image_backend

    if get_image_backend() == "accimage":
        return accimage_loader(path)
    else:
        return pil_loader(path)