如何用MindSpore实现自动数据增强

如何用MindSpore实现自动数据增强

引言

    在深度学习训练的过程中,数据增强有着十分重要的作用。在目前模型设计的工作中,timm库被研究者们广泛使用,其重要的原因之一就是timm库提供了一套非常完备的深度学习工作流程(特别是在数据增强方面),这一套完备的流程可以让模型设计工作者们专注于模型本身的设计,而不用去关心其他复杂的模型训练流程。
    在之前的博客利用MindSpore复现ICCV2021 Best Paper Swin Trasnformer中,我使用MindSpore复现了SwinTransformer,其中遇到最大的难点便是MindSpore本身自带的数据增强库的接口逻辑和timm中的较大差别。当使用MindSpore自带的数据增强库时,虽然可以获得较好的性能,但是由于实现流程的差别,导致我很难复现到作者论文的精度。同时在我自己使用MindSpore进行科研任务、模型编写的过程中,我也经常受到MindSpore数据处理流程不如timm方便的困扰,因此便萌生了一个想法:

能否为MindSpore也提供一套timm形式的数据增强

实现过程

    通过阅读timm和MindSpore的代码,我们可以很简单的发现,其实MindSpore和timm的图像数据库无非都是基于PIL和opencv的,而这两者其实都可以通过numpy、Tensor等框架接口实现非常简单的转换,因此其实我们只要简单的将基于PyTorch、torchvision的代码迁移到numpy和mindspore.Tensor就好,就可以完成MindSpore实现数据增强的效果。

代码使用

代码我已经开源在我的github仓库中,整体的代码基本都是抄自timm仓库的,我只是一个代码的搬运工。对于timm仓库中一些不常用的接口,因为考虑到MindSpore本身的数据下沉已经可以解决那些数据流和模型训练并行的情况,比如use_prefetch等我都进行了一些简单的删除。本篇博客只是介绍一下这个小工具的简单实用。

创建数据

    在创建数据迭代器的流程中,我并没有对timm的接口进行太大的修改,在create_imagenet_loader函数中,数据增强方面,我们可以如timm使用那样一些合适的参数即可。
其中改变的几个参数如下:

  • root:给定数据集的地址,目前只是支持imagenet类的树状摆放的数据集,可以需要给到***/train或者***/val等数据目录即可
  • data_type:目前数据集支持两类数据格式,类似于imagenet的原生数据image类别的,或者使用mindspore自带的接口将其转化为mindrecord的。关于如何将imagenet转成mindrecord,大家可以参考代码imagenet_to_mindrecord.py
def create_imagenet_loader(
        root,
        batch_size,
        input_size,
        num_classes,
        is_training=False,
        no_aug=False,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_split=False,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        num_aug_splits=0,
        interpolation='bilinear',
        cutmix=0.,
        mix_up=0.,
        mixup_prob=0.,
        switch_prob=0.,
        mixup_mode="batch",
        label_smoothing=0.1,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        crop_pct=None,
        num_parallel_workers=8,
        data_type="image",
):

    通过简单阅读mindspore.mindrecord.ImageNetToMR函数,我们就可以了解到实际上就是简单的把路径、图像、标签变成一个字典然后以mindrecord保存下来的。

        imagenet_schema_json = {"label": {"type": "int32"},
                                "image": {"type": "bytes"},
                                "file_name": {"type": "string"}}                               

    加载mindrecord数据集的时候,如果保存了多个文件,我们只需要加载结尾是0的那个文件,函数会自动加载这个是全部数据的(记得到时候把train和val数据集分开放置在不同的文件夹下面

在这里插入图片描述

def imagenet_mind(dataset_dir, num_parallel_workers, shuffle):
    num_shards, rank_id = _get_rank_info()
    files = os.listdir(dataset_dir)
    data_file = list(filter(lambda x: not x.endswith(".db"), files))
    if len(data_file) == 1:
        data_file = data_file[0]
    else:
        data_file = list(filter(lambda x: x.endswith("0"), data_file))[0]
    dataset = MindDataset(data_file, num_parallel_workers=num_parallel_workers,
                          shuffle=shuffle, columns_list=["image", "label"],
                          num_shards=num_shards, shard_id=rank_id)
    return dataset

    关于如何使用ImageFolder,这里就不再赘述~

结尾

    本文就是简单介绍一下为MindSpore迁移的timm数据处理Pipeline,相当于为MindSpore弄了一个第三方的数据增强小套件,希望可以为大家开发降低一些困难。
    最后希望大家如果感觉这个仓库好用的话,给个小星星哦~
https://github.com/Holidays1999/ImgeNet_MindSpore_Pipline

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值