MXNet学习[4]MXNet Data Iterator

MXNet学习[4]MXNet Data Iterator

1、MXNet Data Iterator(数据迭代器)

本文先就DataBatch、DataDesc、DataIter三个主要用到的类进行介绍,然后引出Mxnet中常见的迭代器。最后介绍一种为通用数据格式设计的数据迭代器DataLoaderIter。

2、DataBatch

2.1、介绍

  MXNet中的数据迭代器Data iterator类似于Python迭代器对象。在Python中,函数iter()允许通过对可iterable对象(如Python列表)调用next()按顺序获取item。数据迭代器提供了一个抽象接口,用于遍历各种类型的iterable集合,而无需公开底层数据源的详细信息。
在数据流结束时,当没有更多的数据可读取时,迭代器会引发像Python iter那样的StopIteration异常。

作用:

  • 在MXNet中,数据迭代器在每次调用next()时,返回一批数据作为DataBatch,DataBatch通常包含n个训练示例及其相应的标签,这里的n是Batch的大小。
  • 如果输入的数据是图像的话,这些图像的shape取决于DataDesc中的provide_data参数。每个Batch的训练样本的名称、形状、类型和Layout等信息构成的DataDesc数据描述符及其相应的标签可以通过DataBatch中的provide_data和provide_label属性提供。
  • 如果将布局(layout)设置为“ NCHW”,则图像应以形状(batch_size,num_channel,height,width)的4维矩阵存储。 如果将布局(layout)设置为“ NHWC”,则图像应存储在形状(batch_size,height,width,num_channel)的4-D矩阵中。 通道通常为RGB顺序。

2.2、DataBatch类

class mxnet.io.DataBatch(data, label=None, pad=None, index=None, bucket_key=None, provide_data=None, provide_label=None)

参数:

  • data:一个关于NDArray的列表,每个NDArray都包含了bach_size个大小的样本。a list of input data
  • label:一个关于NDArray的列表,每个NDArray都包含了一维的标签信息。a list of input labels
  • pad: 整型,可选。在最后一个batch时,填充的样本数。当读取的样本总数不能被批大小整除时使用。这些额外的填充样本在预测中被忽略。
  • index:numpy数组格式,可选。该批量中样本的索引
  • bucket_key:整型,可选。The bucket key, used for bucketing module.
  • provide_data:一个关于DataDesc的列表,可选。DataDesc用于存储数据的名字,形状,类型和layout信息。第i个元素描述了data[i]的名字和形状。
  • provide_label:一个关于DataDesc的列表,可选。DataDesc用于存储数据的名字,形状,类型和layout信息。第i个元素描述了label[i]的名字和形状。

3、DataDesc

3.1、介绍

作用

  • DataDesc用于存储数据的名字,形状,类型和格式信息。
  • 每个Batch训练样本的名称、形状、类型和Layout等信息及其相应的标签可以通过DataBatch中的provide_data和provide_label属性提供,而DataDesc数据描述符对象包含了数据的名字,形状,类型和layout信息。
  • Layout描述了shape中轴的解释方式,例如,对于图像数据,layout = NCHW表示第一个轴是Batch(N)中的样本数,C是通道数,H是高度,W是 图片的宽度。对于顺序数据sequential --data,默认情况下Layout设置为NTC,其中N是Batch中的样本数,T是代表时间的时间轴,C是通道数。
class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):

参数

  • cls(DataDesc):类自己
  • name:字符串,数据名字
  • shape:元组或整型,数据形状
  • dtype:nd.dtype 可选。数据类型
  • layout:字符串,可选。数据格式。包括 NCHW\NHWC

方法

  • get_batch_axis(layout):获取与批处理大小相对应的维度。
  • get_list(shapes, types):从属性列表中获取DataDesc列表.

4、DataIter

4.1、介绍

作用

  • DataIter是mxnet中数据迭代器dataiter的基类(The base class for an MXNet data iterator)
  • mxnet中所有的数据IO都由mx.io.DataIter的子类来处理。mxnet中的dataiter迭代器是和python中的iterators很像,每次调用nxet都会返回一个Databatch,代表了一个批量中的数据,当没有更多的数据返回时,它会引发StopIteration异常。
  • MXnet中的Data Iterator和python中的迭代器是很相似的, 当Data Iterator的内置方法next()被调用的时候,它每次返回一个 DataBatch。所谓DataBatch,就是网络所需要的输入和label,一般是(n, c, h, w)的格式的图片输入和(n, h, w)或者标量式样的label。

4.2、DataIter类

class mxnet.io.DataIter(batch_size=0)

参数

  • batch_size:the batch size,即批次中的项目数量。
    方法
  • getdata():获取当前批次的数据。
  • getindex():获取当前批的索引。
  • getlabel():获取当前批次的标签。
  • getpad():获取当前批处理中的填充样本数。
  • iter_next():移到下一批。
  • next():从迭代器获取下一个数据批。
  • reset():将迭代器重置为数据的开头

5、Data iterators:Mxnet中所有常用的迭代器

MXNet中的所有IO都通过mx.io.DataIter以及它的子类来处理,以下是常见的几种迭代器。
在这里插入图片描述

6、Custom Iterator(自己定制一个迭代器)

当所有内置的迭代器不能满足时,可以定制。mxnet中的迭代器应当满足:

  • 如果是py2应实现next(),py3的话应实现__next()__,并返回一个DataBatch或升起一个StopIteration意外当迭代到最后的时候。
  • 实现reset()方法来返回到迭代器头部
  • 实现provide_data属性
  • 实现provide_label属性

创建新的迭代器时,可以从头开始定义迭代器,也可以重用现有迭代器,例如,在图像caption应用程序中(看图说话),输入示例是图像,而标签是句子。可以通过以下方法创建新的迭代器:

  • 使用ImageRecordIter创建一个image_iter,它提供多线程预取和扩充。
  • 使用rnn包中提供的NDArrayIter或bucketing迭代器创建caption_iter。
  • next()返回image_iter.next()和caption_iter.next()

6.1、定制一个迭代器

class SimpleIter(mx.io.DataIter):# MXNet中的所有IO都通过mx.io.DataIter以及它的子类来处理
    def __init__(self, data_names, data_shapes, data_gen,
                 label_names, label_shapes, label_gen, num_batches=10):
        self._provide_data = list(zip(data_names, data_shapes))
        self._provide_label = list(zip(label_names, label_shapes))
        self.num_batches = num_batches
        self.data_gen = data_gen
        self.label_gen = label_gen
        self.cur_batch = 0

    def __iter__(self):
        return self

    def reset(self):
        self.cur_batch = 0

    def __next__(self):
        return self.next()

    @property
    def provide_data(self):
        return self._provide_data

    @property
    def provide_label(self):
        return self._provide_label

    def next(self):
        if self.cur_batch < self.num_batches:
            self.cur_batch += 1
            data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)]
            label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)]
            return mx.io.DataBatch(data, label)
        else:
            raise StopIteration
  • 上面的代码是最简单的一个DataIter,没有对数据的预处理,甚至于没有自己去读取数据,但是基本的意思是到了,一个DataIter必须要实现上面的几个方法
  • provide_data返回的格式是(dataname, batchsize, channel, width, height), provide_label返回的格式是(label_name, batchsize, width, height)
  • reset()的目的是在每个epoch后打乱读取图片的顺序,这样随机采样的话训练效果会好一点,一般情况下是用shuffle你的lst(上篇用来读取图片的lst)实现的
  • next()的方法就很显然了,用来返回你的DataBatch,如果出现问题…记得raise stopIteration,这里或许用try更好吧…需要注意的是,DataBatch返回的数据类型是mx.nd.ndarry。

转载:https://blog.csdn.net/qq_35091353/article/details/108759726

相关推荐
©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页