MXNet学习[4]MXNet Data Iterator
文章目录
1、MXNet Data Iterator(数据迭代器)
本文先就DataBatch、DataDesc、DataIter三个主要用到的类进行介绍,然后引出Mxnet中常见的迭代器。最后介绍一种为通用数据格式设计的数据迭代器DataLoaderIter。
2、DataBatch
2.1、介绍
MXNet中的数据迭代器Data iterator类似于Python迭代器对象。在Python中,函数iter()允许通过对可iterable对象(如Python列表)调用next()按顺序获取item。数据迭代器提供了一个抽象接口,用于遍历各种类型的iterable集合,而无需公开底层数据源的详细信息。
在数据流结束时,当没有更多的数据可读取时,迭代器会引发像Python iter那样的StopIteration异常。
作用:
- 在MXNet中,数据迭代器在每次调用next()时,返回一批数据作为DataBatch,DataBatch通常包含n个训练示例及其相应的标签,这里的n是Batch的大小。
- 如果输入的数据是图像的话,这些图像的shape取决于DataDesc中的provide_data参数。每个Batch的训练样本的名称、形状、类型和Layout等信息构成的DataDesc数据描述符及其相应的标签可以通过DataBatch中的provide_data和provide_label属性提供。
- 如果将布局(layout)设置为“ NCHW”,则图像应以形状(batch_size,num_channel,height,width)的4维矩阵存储。 如果将布局(layout)设置为“ NHWC”,则图像应存储在形状(batch_size,height,width,num_channel)的4-D矩阵中。 通道通常为RGB顺序。
2.2、DataBatch类
class mxnet.io.DataBatch(data, label=None, pad=None, index=None, bucket_key=None, provide_data=None, provide_label=None)
参数:
- data:一个关于NDArray的列表,每个NDArray都包含了bach_size个大小的样本。a list of input data
- label:一个关于NDArray的列表,每个NDArray都包含了一维的标签信息。a list of input labels
- pad: 整型,可选。在最后一个batch时,填充的样本数。当读取的样本总数不能被批大小整除时使用。这些额外的填充样本在预测中被忽略。
- index:numpy数组格式,可选。该批量中样本的索引
- bucket_key:整型,可选。The bucket key, used for bucketing module.
- provide_data:一个关于DataDesc的列表,可选。DataDesc用于存储数据的名字,形状,类型和layout信息。第i个元素描述了data[i]的名字和形状。
- provide_label:一个关于DataDesc的列表,可选。DataDesc用于存储数据的名字,形状,类型和layout信息。第i个元素描述了label[i]的名字和形状。
3、DataDesc
3.1、介绍
作用
- DataDesc用于存储数据的名字,形状,类型和格式信息。
- 每个Batch训练样本的名称、形状、类型和Layout等信息及其相应的标签可以通过DataBatch中的provide_data和provide_label属性提供,而DataDesc数据描述符对象包含了数据的名字,形状,类型和layout信息。
- Layout描述了shape中轴的解释方式,例如,对于图像数据,layout = NCHW表示第一个轴是Batch(N)中的样本数,C是通道数,H是高度,W是 图片的宽度。对于顺序数据sequential --data,默认情况下Layout设置为NTC,其中N是Batch中的样本数,T是代表时间的时间轴,C是通道数。
class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
参数
- cls(DataDesc):类自己
- name:字符串,数据名字
- shape:元组或整型,数据形状
- dtype:nd.dtype 可选。数据类型
- layout:字符串,可选。数据格式。包括 NCHW\NHWC
方法
- get_batch_axis(layout):获取与批处理大小相对应的维度。
- get_list(shapes, types):从属性列表中获取DataDesc列表.
4、DataIter
4.1、介绍
作用
- DataIter是mxnet中数据迭代器dataiter的基类(The base class for an MXNet data iterator)
- mxnet中所有的数据IO都由mx.io.DataIter的子类来处理。mxnet中的dataiter迭代器是和python中的iterators很像,每次调用nxet都会返回一个Databatch,代表了一个批量中的数据,当没有更多的数据返回时,它会引发StopIteration异常。
- MXnet中的Data Iterator和python中的迭代器是很相似的, 当Data Iterator的内置方法next()被调用的时候,它每次返回一个 DataBatch。所谓DataBatch,就是网络所需要的输入和label,一般是(n, c, h, w)的格式的图片输入和(n, h, w)或者标量式样的label。
4.2、DataIter类
class mxnet.io.DataIter(batch_size=0)
参数
- batch_size:the batch size,即批次中的项目数量。
方法 - getdata():获取当前批次的数据。
- getindex():获取当前批的索引。
- getlabel():获取当前批次的标签。
- getpad():获取当前批处理中的填充样本数。
- iter_next():移到下一批。
- next():从迭代器获取下一个数据批。
- reset():将迭代器重置为数据的开头
5、Data iterators:Mxnet中所有常用的迭代器
MXNet中的所有IO都通过mx.io.DataIter以及它的子类来处理,以下是常见的几种迭代器。
6、Custom Iterator(自己定制一个迭代器)
当所有内置的迭代器不能满足时,可以定制。mxnet中的迭代器应当满足:
- 如果是py2应实现next(),py3的话应实现__next()__,并返回一个DataBatch或升起一个StopIteration意外当迭代到最后的时候。
- 实现reset()方法来返回到迭代器头部
- 实现provide_data属性
- 实现provide_label属性
创建新的迭代器时,可以从头开始定义迭代器,也可以重用现有迭代器,例如,在图像caption应用程序中(看图说话),输入示例是图像,而标签是句子。可以通过以下方法创建新的迭代器:
- 使用ImageRecordIter创建一个image_iter,它提供多线程预取和扩充。
- 使用rnn包中提供的NDArrayIter或bucketing迭代器创建caption_iter。
- next()返回image_iter.next()和caption_iter.next()
6.1、定制一个迭代器
class SimpleIter(mx.io.DataIter):# MXNet中的所有IO都通过mx.io.DataIter以及它的子类来处理
def __init__(self, data_names, data_shapes, data_gen,
label_names, label_shapes, label_gen, num_batches=10):
self._provide_data = list(zip(data_names, data_shapes))
self._provide_label = list(zip(label_names, label_shapes))
self.num_batches = num_batches
self.data_gen = data_gen
self.label_gen = label_gen
self.cur_batch = 0
def __iter__(self):
return self
def reset(self):
self.cur_batch = 0
def __next__(self):
return self.next()
@property
def provide_data(self):
return self._provide_data
@property
def provide_label(self):
return self._provide_label
def next(self):
if self.cur_batch < self.num_batches:
self.cur_batch += 1
data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)]
label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)]
return mx.io.DataBatch(data, label)
else:
raise StopIteration
- 上面的代码是最简单的一个DataIter,没有对数据的预处理,甚至于没有自己去读取数据,但是基本的意思是到了,一个DataIter必须要实现上面的几个方法
- provide_data返回的格式是(dataname, batchsize, channel, width, height), provide_label返回的格式是(label_name, batchsize, width, height)
- reset()的目的是在每个epoch后打乱读取图片的顺序,这样随机采样的话训练效果会好一点,一般情况下是用shuffle你的lst(上篇用来读取图片的lst)实现的
- next()的方法就很显然了,用来返回你的DataBatch,如果出现问题…记得raise stopIteration,这里或许用try更好吧…需要注意的是,DataBatch返回的数据类型是mx.nd.ndarry。
转载:https://blog.csdn.net/qq_35091353/article/details/108759726