前言
之前在GPU集群上配置的caffe因为一系列人为因素崩溃,搭建的Tensorflow由于cuda版本太低有些实验不能跑,
而恰逢管理员不在,只好找一款不受这些因素影响的框架,之前阅读过mxnet源码,源码不多,很容易懂。于是配置
了一把,成功只好用了。而我的实验基于多源数据,即包含两种输入,这里是face images和audio数据,如果单纯使用
官方提供DataIter不能够完成任务,只好自己写(由于数据较大,直接使用NDArrayIter不现实,不如直接自己重新设计
一种DataIter),当然本文章着重讲解如何自定义DataIter,细节还需参看源码。
话不多说,上干货。
目录
Mxnet中的DataIter
DataIter
NDArrayIter
MXDataIter
自定义DataIter
Mxnet中的DataIter
DataIter类对象都在模块io.py
中,而所有的DataIter都继承于基类DataIter
,其中DataIter
源码如下:
class DataIter(object):
def __init__(self):
self.batch_size = 0
def __iter__(self):
return self
def reset(self):
pass
def next(self):
if self.iter_next():
return mx.io.DataBatch(data=self.getdata(), label=self.getlabel(), \
pad=self.getpad(), index=self.getindex())
else:
raise StopIteration
def __next__(self):
return self.next()
def iter_next</