MindSpore系列一加载图像分类数据集

        MindSpore提供了大部分常用数据集和标准格式数据集的加载接口,可以直接使用mindspore.dataset中对应的数据集加载类进行数据加载,如MNIST、CIFAR-10、CIFAR-100、VOC、COCO、ImageNet、CelebA、CLUE等, 以及业界标准格式的数据集,包括MindRecord、TFRecord、Manifest等。

常用数据集加载以cifar10为例,首先将cifar10数据集下载并解压到本地。

1、加载cifar10数据集:

DATA_DIR = "./cifar-10-batches-bin/"
sampler = ds.SequentialSampler(num_samples=5)
dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)

create_dict_iterator创建数据迭代器访问数据:

for data in dataset.create_dict_iterator():
    print("Image shape: {}".format(data['image'].shape), ", Label: {}".format(data['label']))

2、加载自定义图像分类数据集

         使用mindspore加载自定义图像分类数据,可以使用mindspore.dataset.ImageFolderDataset接口进行加载。将相同类别的图像放在同一文件夹下,不同类别以不同文件夹区分,将所有分类的上级目录传入ImageFolderDataset接口,mindspore会自动加载图像数据并根据不同文件夹分配对应标签。

 

         这里以TinyImageNet为例进行数据加载。首先,使用imageFolderDataset接口传入数据路径,通过num_parallel_worker可设置数据加载并行线程数,shuffle参数设置是否打乱数据顺序。另外需要通过map接口进行图像数据预处理,图像预处理接口mindspore.dataset.vision.c_transforms,通过c_transforms可进行图像解码,缩放归一化,矩阵转置等操作。

import mindspore
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore import dtype as mstype

def create_dataset(data_path, batch_size=24, c_transforms
 repeat_num=1):
    """定义数据集"""
    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    if parallel_mode == context.ParallelMode.DATA_PARALLEL:
        data_set = ds.ImageFolderDataset(data_path, num_parallel_def create_dataset(data_path, batch_size=24, repeat_num=1):
    """定义数据集"""
    data_set = ds.ImageFolderDataset(data_path, num_parallel_workers=8, shuffle=True)
    image_size = [100, 100]
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    trans = [
        CV.Decode(),
        CV.Resize(image_size),
        CV.Normalize(mean=mean, std=std),
        CV.HWC2CHW()
    ]

    # 实现数据的map映射、批量处理和数据重复的操作
    type_cast_op = C.TypeCast(mstype.int32)
    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    data_set = data_set.batch(batch_size, drop_remainder=True)
    data_set = data_set.repeat(repeat_num)

    return data_set

        数据迭代,ImageFolderDataset通过create_tuple_iterator()接口对数据集进行迭代,每次迭代一个batch的数据。

if __name__ == '__main__':
    datapath = 'D:/Sources/Data/datasets/TinyImageNet/val'
    ds = create_dataset(datapath, batch_size=8)
    iterator = ds.create_tuple_iterator()
    for item in iterator:
        print(f'images:{mindspore.Tensor(item[0]).shape},labels:{item[1]}')

 

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

TheMatrixs

你的鼓励将是我创作的最大动力!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值