下面的内容只是简写（部分代码已省略）：
import mxnet as mx
import numpy as np
class Mix(mx.io.DataIter):
    """Randomly interleave batches from several RecordIO image iterators.

    One ``mx.io.ImageRecordIter`` is built per entry in *paths*, all sharing
    the same ``**kwargs``. Each call to :meth:`next` picks one of the
    not-yet-exhausted sub-iterators uniformly at random and returns its next
    batch. A sub-iterator is dropped once it raises ``StopIteration``, and
    ``Mix`` itself is exhausted when every source is.

    NOTE(review): this class is shown abbreviated (the original snippet ends
    with "......."); anything the omitted part defines (e.g. the
    ``provide_data`` / ``provide_label`` that ``Module.fit`` expects) is not
    visible here.
    """

    def __init__(self, paths, **kwargs):
        """Build one ImageRecordIter per RecordIO path.

        Args:
            paths: sequence of RecordIO file paths, one per data source.
            **kwargs: forwarded unchanged to every ``mx.io.ImageRecordIter``
                (e.g. ``data_shape``, ``batch_size``).
        """
        # DataIter.__init__ records batch_size; reuse the one the sub-iterators get.
        super(Mix, self).__init__(batch_size=kwargs.get('batch_size', 0))
        # Materialise as a list: on Python 3 ``range`` has no ``remove``.
        self.flags = list(range(len(paths)))
        # One counter per source index, starting at 0.
        # NOTE(review): presumably updated in the omitted part — confirm.
        self.numbers = dict.fromkeys(self.flags, 0)
        # MXNet's generated data iterators accept keyword arguments only,
        # so the path must be passed as ``path_imgrec=``, not positionally.
        self.generator = lambda path: mx.io.ImageRecordIter(path_imgrec=path, **kwargs)
        # A real list (not a lazy ``map``) so it can be indexed, len()'d and
        # iterated repeatedly across resets.
        self.recs = [self.generator(path) for path in paths]

    def reset(self):
        """Rewind every source iterator and mark all of them active again."""
        for rec in self.recs:
            rec.reset()
        self.flags = list(range(len(self.recs)))

    def next(self):
        """Return the next batch from a randomly chosen, non-exhausted source.

        Returns:
            The batch produced by the chosen sub-iterator; it is also stored
            in ``self.batch``.

        Raises:
            StopIteration: once every source iterator is exhausted.
        """
        while self.flags:
            idx = np.random.choice(self.flags)
            try:
                self.batch = self.recs[idx].next()
            except StopIteration:
                # This source is drained: stop drawing from it, try another.
                self.flags.remove(idx)
                continue
            # NOTE(review): the original snippet elides what follows the loop
            # ("......."); returning the fetched batch is assumed here —
            # merge with the omitted code.
            return self.batch
        # No active source left. This also covers next() being called again
        # after exhaustion without an intervening reset().
        raise StopIteration
.......