pytorch图像数据集定义

本文介绍了如何在PyTorch中使用Dataset类定义数据集,如何应用transforms进行预处理,包括PILToTensor和Normalize,以及如何利用PytorchLightning和LightningDataModule进行更高级的模型训练和数据管理。着重讲解了ImageFolder和VisionDataset的用法,以及如何构建训练、验证和测试数据加载器。
摘要由CSDN通过智能技术生成


对于图像数据集来说,首先是在Dataset类对数据集进行定义,一般来说不定义transform,则数据为PIL Image,PIL格式到tensor的转换也是transforms变换的一种,所以定Dataset+transforms+Dataloader,最后在训练部分to(device)来得到模型的输入。

相关链接

torchvision.datasets的三个基础类
torchvision.datasets
torch.utils.data.Dataset
Pillow(PIL Fork) Image模块

Dataset

Dataset是数据集在pytorch中的化身,需要重写__ getitem__ 和 __ len__。__ __ getitem__ 通过传入的索引加载指定路径的数据,路径常常是一个列表,如很多张图片组成的数据集,需要在初始化时定义函数得到路径列表,或者在外部定义,总之要得到一个路径List。也需要在其中定义或调用具体读取的代码,如PIL库的Image.open()来读取图片,或Image.fromarray()来创建图片,也就是需要知道数据在哪里和怎么读取。

└─Dataset
    └─VisionDataset
        └─DatasetFolder
            └─ImageFolder

Dataset是torch.utils.data中的类,是数据集的基础类

VisionDataset

VisionDataset是torchvision.datasets.vision中的类,是torchvision类数据集的基础类,相比于原始的Dataset类,提供了transform,transforms,target_transform数据变换的接口

DatasetFolder,ImageFolder都来自torchvision.datasets.folder ,既然叫做folder,实际上已经有了完整的数据集功能,可以按照默认的目录结构读取数据。DatasetFolder还需要定义loader以读取特定类型的数据,和is_valid_file或者extensions,is_valid_file和extensions不能同时定义,但必须有一个定义,如果定义了有效后缀名,会自动通过后缀来判断文件有效性。而ImageFolder更进一步,默认使用读取图像数据的loader读取,还默认定义了图像后缀名。从Dataset到ImageFolder构成了不同层次的封装,完成度越高,灵活性越低,可以根据自己的需要选择。

除了在__ getitem__ 中通过得到的路径列表来读取数据,对于不同格式的数据也有不同的做法,如torchvision中内置cifar数据集,会直接从原始数据中以矩阵的形式读取, 因此 __ getitem__ 会从矩阵中创建Image对象。总而言之,一般来讲对于图片数据集来说,__ get __返回的都是PIL Image对象,不管是从路径列表中读取,还是整个以矩阵形式读取,如果不定义transform,最后在Dataset阶段都是PIL对象。

DatsetFolder

默认的排列结构如下,每一个文件夹表示一类,下面是这一类的样本

​      directory/

​      ├── class_x

​      │  ├── xxx.ext

​      │  ├── xxy.ext

​      │  └── ...

​      │    └── xxz.ext

​      └── class_y

​        ├── 123.ext

​        ├── nsdf3.ext

​        └── ...

​        └── asd932_.ext

用文件夹来区分不同的类别。比较重要的有两类操作,find_class函数得到类别名和类别序号。make_dataset得到路径列表。

默认的findclass函数

文件夹名是类名。

def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

默认的make_dataset

得到instance列表,表示文件的路径列表。

基本上很大一部分是在定义有效性判断相关,主要部分是一个双层for循环,因为类名定义为文件夹名,所以会遍历各个类的文件夹,会将遍历到的有效文件的路径加入instance,遍历过的非空类添加到available_classe。

def make_dataset(
    directory: Union[str, Path],
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    allow_empty: bool = False,
) -> List[Tuple[str, int]]:
 
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = path, class_index
                    instances.append(item)

                    if target_class not in available_classes:
                        available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes and not allow_empty:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances

ImageFolder

 root/dog/xxy.png
 root/dog/[...]/xxz.png

 root/cat/123.png
 root/cat/nsdf3.png

ImageFolder如名字所示,如果数据集是这种文件夹排列,而且是图像文件,又没有需要特殊定义的部分 ,可以直接实例化一个ImageFolder,而不需要重写任何部分 ,实例化一个数据集只需要传入数据集路径和tansform变换。

train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(
    os.path.join(data_dir, 'train_valid_test', folder),
    transform=transform_train) for folder in ['train', 'train_valid']]

torchvision.transforms

一般会在数据集实例化时,从外部传入,通常自定义的Transforms序列包含ToTensor,可以将上一阶段的PIL Image转换为Tensor,而下一次变化要到训练时的to(device),这样数据最终输入完成,也可以在Dataset类中写入默认的transform。

通过torchvision.get_image_backend得到torchvision现在的后端默认为PILtorchvision.set_image_backend(backend)指定用来读取图片的包,可选accimage

Loader将数据读取为PIL对象,一般数据集定义不在数据集内部定义默认的transform图像变换,而是在外部定义一个transform序列,通常倒数第二个是torchvision.transforms.ToTensor()操作,会将一个PIL Image或者一个ndarray转换为tensor并缩放到[0.0, 1.0]。因此接下来会通过transforms.Normalize进行归一化。

PILToTensor会把PIL Image转化为tensor,但是不会进行缩放, ( H × W × C ) → ( C × H × W ) (H\times W\times C)\rightarrow (C\times H \times W) (H×W×C)(C×H×W)

ToTensor会把PIL Image或者ndarray转换成tensor而且会进行缩放。 ( H × W × C ) → ( C × H × W ) (H\times W\times C)\rightarrow (C\times H \times W) (H×W×C)(C×H×W)​ 在规定的模式如RGBA,RGB,YCbCr或者dtype = np.uint8情况下,别的情况下不缩放。

Normalize只支持tensor,其他大部分操作也支持PIL,所以在ToTensor之后最后进行Normalize

data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

Pytorch Lightning

Pytorch Lightning是Pytorch中的kersas,简称pl

LightningDataModule

Pytorch Lightning继承LightningDataModule定义数据集,pl中的Dataset和Dataloader是高度耦合的。

import lightning.pytorch as L
import torch.utils.data as data
from pytorch_lightning.demos.boring_classes import RandomDataset

        class MyDataModule(L.LightningDataModule):
            def prepare_data(self):
                # download, IO, etc. Useful with shared filesystems
                # only called on 1 GPU/TPU in distributed
                ...

            def setup(self, stage):
                # make assignments here (val/train/test split)
                # called on every process in DDP
                dataset = RandomDataset(1, 100)
                self.train, self.val, self.test = data.random_split(
                    dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
                )

            def train_dataloader(self):
                return data.DataLoader(self.train)

            def val_dataloader(self):
                return data.DataLoader(self.val)

            def test_dataloader(self):
                return data.DataLoader(self.test)

            def teardown(self):
                # clean up state after the trainer stops, delete files...
                # called on every process in DDP
                ...
  • 24
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值