如何用MindSpore自定义数据集

引言

    在深度学习模型的训练过程中,数据集是起着至关重要作用的。然而,由于任务的复杂性,深度学习模型的输入数据也有着各种各样的形式,深度学习模型搭建的过程中,如果遇到特别复杂的数据,研究者可能要花费大半的时间在数据集的预处理(包括清洗、加载等过程)中。因此,高效的加载数据集,能给研究者构建一套高效的开发流程。
    使用过PyTorch的读者都知道,PyTorch框架为我们提供了一套极其便利且高效率的自定义数据加载的接口。用户只需要简单的继承torch.utils.data.Dataset并且在__get_item__函数和__len__函数,再利用Dataloader进行封装,就可以很简单的实现数据集的自动化加载流程(个人认为设置PyTorch在数据层面上做的超级好的一个点)。

如何用MindSpore自定义数据集

MindSpore数据集加载简介

    在MindSpore中,mindspore.dataset里面的函数为我们提供了大量的数据集专有加载算子,这些算子经过优化,拥有较好的数据集加载性能。但是,由于MindSpore本身的数据加载都是在C语言层面完成的,用户很难感知到内部进行的具体操作,特别是针对coco这一类较为复杂的数据集时(就是比较黑洞,很难自己掌握)。由于笔者是一个很喜欢把模型训练的每一步都抓在自己手里的一个人,因此除了cifar10、cifar100、imagefolder等经典的数据(结构)时,尽量都希望自己完成数据集的加载流程,以便更好的了解模型模型和数据集。因此,这篇博客将会主要介绍如何使用MindSpore自定自定义类似PyTorch范式的数据集加载流程。

mindspore.dataset.GeneratorDataset

    区别用PyTorch,MindSpore并不能像继承Dataset来完成数据集的构建,但是MindSpore为用户提供了一个类似于DataLoader的数据集封装接口。
    用户可以通过自定义object对象的数据集对象,然后使用GeneratorDataset进行封装,接下来我将以自定义cifar10和imagenet数据集来简单展示使用GeneratorDataset接口的方法

自定义cifar10数据集

分析格式

    在定义数据集之前,我们首先要做的就是数据集的格式分析。在cifar官网中,我们可以得知数据集的基本格式,还可以通过已有的博客,查看读取cifar10的代码样例。
如下图所示是cifar-10-batches-py数据集的目录文件,这里我们主要是关注data_batch和test_batch。
在这里插入图片描述

加载数据

    这里我主要以torchvision中的cifar10数据加载为例,说明构建cifar10数据集的方法。

    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    	...
        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list
        ...
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])
		"""可以很容易理解到,数据集文件里面有一个"data"和一个"label"键,分别拿出来就好"""
        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

构建cifar10数据集并且完成预处理

    由于cifar10读取进来以后已经是数据形式,因此并不需要想用的图像解码,可以直接使用opencv或者PIL进行处理。这里以cifar10的test数据为例。

import os
import pickle
import numpy as np
import mindspore
from mindspore.dataset import GeneratorDataset


class CIFAR10(object):
    train_list = [
        'data_batch_1',
        'data_batch_2',
        'data_batch_3',
        'data_batch_4',
        'data_batch_5',
    ]

    test_list = [
        'test_batch',
    ]

    def __init__(self, root, train, transform=None, target_transform=None):
        super(CIFAR10, self).__init__()

        self.root = root
        self.train = train  # training set or test set
        
        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list
        self.data = []
        self.targets = []
        self.transform = transform
        self.target_transform = target_transform

        # now load the picked numpy arrays
        for file_name in downloaded_list:
            file_path = os.path.join(self.root, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])
    
        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


    def __len__(self):
        return len(self.data)

cifar10_test = CIFAR10(root="./cifar10/cifar-10-batches-py", train=False)
cifar10_test = GeneratorDataset(source=cifar10_test, column_names=["image", "label"])
cifar10_test = cifar10_test.batch(128)
for data in cifar10_test.create_dict_iterator():
    print(data["image"].shape, data["label"].shape)

(128, 32, 32, 3) (128,)
(128, 32, 32, 3) (128,)
(128, 32, 32, 3) (128,)
(128, 32, 32, 3) (128,)

    可以从上面的代码看到,虽然语言风格不同,但是MIndSpore使用GeneratorDataset依然可以为我们提供一套相对便利的数据集加载方式。对于数据集的预处理的transform代码,研究者可以将代码直接通过transform参数传入__get_item__函数,十分方便;同时也可以使用mindspore语言风格,通过dataset自带的map函数,对数据集进行预处理,不过前者的语言风格更加python,推荐使用。

自定义ImageNet

分析格式

    接下来是介绍ImageNet的数据集自定义过程。其实定义ImageNet数据集加载器是非常方便的,因为图像分类的这类数据集往往是具有树状结构,我们只需要[路径,标签]或者是[图像,标签]的数组对传入到__get_item__函数中,就可以完成数据集的预处理。

数据加载

    这里就简单引用timm中定义folder的部分代码。

def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
    labels = []
    filenames = []
    for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
        rel_path = os.path.relpath(root, folder) if (root != folder) else ''
        label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
        for f in files:
            base, ext = os.path.splitext(f)
            if ext.lower() in types:
                filenames.append(os.path.join(root, f))
                labels.append(label)
    if class_to_idx is None:
        # building class index
        unique_labels = set(labels)
        sorted_labels = list(sorted(unique_labels, key=natural_key))
        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
    images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
    if sort:
        images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
    return images_and_targets, class_to_idx

    可以看到,我们只需要遍历目录,得到images_and_target就好。

Mixup和Cutmix的使用

    在ImageNet中,我们常常会使用Mixup和Cutmix等数据增强,但是在对齐进行数据增强的时候,数据集已经是变成[batch_size, channel, height, width]形式出来的,在__get_item__进行数据预处理的函数是针对单个样本的。
    在PyTorch中,Mixup和Cutmix是在将数据取出,输入模型之前应用的。在mindspore中,我们只需要在使用dataset.batch函数之后再对数据集进行预处理。具体的代码可以参考我的博客如何用MindSpore实现自动数据增强,这里展示部分代码。

   if (mix_up > 0. or cutmix > 0.) and not is_training:
        # if use mixup and not training(False), one hot val data label
        one_hot = C.OneHot(num_classes=num_classes)
        dataset = dataset.map(input_columns="label", num_parallel_workers=num_parallel_workers,
                              operations=one_hot)
    dataset = dataset.batch(batch_size, drop_remainder=True, num_parallel_workers=num_parallel_workers)
    if (mix_up > 0. or cutmix > 0.) and is_training:
        mixup_fn = Mixup(
            mixup_alpha=mix_up, cutmix_alpha=cutmix, cutmix_minmax=None,
            prob=mixup_prob, switch_prob=switch_prob, mode=mixup_mode,
            label_smoothing=label_smoothing, num_classes=num_classes)

        dataset = dataset.map(operations=mixup_fn, input_columns=["image", "label"],
                              num_parallel_workers=num_parallel_workers)
    return dataset

FAQ

    自定义数据集的时候,千万要注意要重载__len__函数,没有这个函数,对象是无法感知数据集大小的。

总结

    本文介绍了如何使用GeneratorDataset这个接口自定义MindSpore数据集。虽然MindSpore为我们提供了好用的专有数据算子,但是由于数据加载在C语言层面完成,相对于torchvision来说存在着无法感知的缺陷,因此可以尝试使用GeneratorDataset自定义加载,把握每一步细节。(当然,其实也可以去torchvision搬代码拿GeneratorDataset封装就好~)

  • 0
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
MindSpore提供了`mindspore.dataset`模块来处理数据集。你可以根据自己的数据集格式来创建数据集,并将其转换为MindSpore数据集。 以下是一个示例,假设你有一个文本分类数据集,其中包含一个文本文件和一个标签文件,每行文本文件包含一条数据,每行标签文件包含相应数据的标签。你可以使用以下代码将数据集转换为MindSpore数据集: ```python import mindspore.dataset as ds import mindspore.dataset.text as text # 定义数据集文件路径 data_file = "./data/text.txt" label_file = "./data/label.txt" # 定义数据集处理操作 data_ops = text.WhitespaceTokenizer() # 使用空格作为分词器 label_ops = text.ToNumber(output_type=ms.int32) # 将标签转换为整数类型 # 创建数据集 dataset = ds.TextFileDataset([data_file, label_file], num_samples=None) # 对数据集进行处理 dataset = dataset.map(operations=data_ops, input_columns=["text"]) dataset = dataset.map(operations=label_ops, input_columns=["label"]) # 打印数据集信息 print(dataset.output_shapes()) # 输出 [(None,), (None,)] print(dataset.output_types()) # 输出 [dtype('string'), dtype('int32')] ``` 在上面的代码中,我们首先定义了数据集文件的路径,然后定义了数据集处理操作。我们使用`text.WhitespaceTokenizer()`将文本文件中的每一行按空格进行分词,并使用`text.ToNumber()`将标签文件中的每一行转换为整数类型。接着,我们使用`ds.TextFileDataset()`读取文本文件和标签文件,并使用`ds.map()`对数据集进行处理。最后,我们使用`dataset.output_shapes()`和`dataset.output_types()`分别打印数据集的形状和类型信息。 你可以根据自己的数据集格式和需求,调整相应的数据集处理操作。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值