1. Reading .rec files with ImageRecordIter
mxnet.io.ImageRecordIter(*args, **kwargs)
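This iterator can only read batches from a .rec file. It is less flexible than a custom input pipeline, but it is fast. To read raw image files, use ImageIter instead.
Example: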
import mxnet as mx

data_iter = mx.io.ImageRecordIter(
    path_imgrec="./sample.rec",  # The target record file.
    data_shape=(3, 227, 227),  # Output data shape; a 227x227 region will be cropped from the original image.
    batch_size=4,  # Number of items per batch.
    resize=256  # Resize the shorter edge to 256 before cropping.
    # You can specify more augmentation options. Use help(mx.io.ImageRecordIter) to see all the options.
)
# You can now use the data_iter to access batches of images.
batch = data_iter.next()  # First batch.
images = batch.data[0]  # This will contain 4 (=batch_size) images, each of shape 3x227x227.
# Process the images here.
...
data_iter.reset()  # To restart the iterator from the beginning.
Various augmentation operations can be specified through the parameters. For the full parameter list, see:
http://mxnet.incubator.apache.org/versions/master/api/python/io/io.html?highlight=record
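To make the interplay of resize=256 and data_shape=(3, 227, 227) concrete, here is a small library-free sketch of the geometry: the shorter edge is scaled to the resize value, then a window of the requested size is cut out. The function name resize_then_center_crop is hypothetical, and the centered window is an assumption (options such as random cropping change where the window is placed). MXNet's own rounding may also differ slightly.

```python
def resize_then_center_crop(width, height, resize, crop_w, crop_h):
    """Return ((new_w, new_h), (x0, y0, x1, y1)) for a shorter-edge resize
    followed by a centered crop. Illustrative only, not MXNet code."""
    # Scale so that the shorter edge becomes `resize`, keeping the aspect ratio.
    if width <= height:
        new_w, new_h = resize, int(height * resize / width)
    else:
        new_w, new_h = int(width * resize / height), resize
    if crop_w > new_w or crop_h > new_h:
        raise ValueError("crop window is larger than the resized image")
    # Center the crop window inside the resized image (assumed placement).
    x0 = (new_w - crop_w) // 2
    y0 = (new_h - crop_h) // 2
    return (new_w, new_h), (x0, y0, x0 + crop_w, y0 + crop_h)


# A 640x480 source image with resize=256 and a 227x227 crop.
size, box = resize_then_center_crop(640, 480, 256, 227, 227)
print(size, box)
```

Note that the crop size must not exceed the resized shorter edge, which is why resize (256) is chosen larger than the crop (227).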
2. Reading .rec files or raw images with mxnet.image.ImageIter
class mxnet.image.ImageIter(
    batch_size,
    data_shape,  # Only 3-channel RGB images are supported.
    label_width=1,
    path_imgrec=None,
    path_imglist=None,
    path_root=None,
    path_imgidx=None,
    shuffle=False,
    part_index=0,
    num_parts=1,
    aug_list=None,
    imglist=None,
    data_name='data',
    label_name='softmax_label',
    dtype='float32',
    last_batch_handle='pad',
    **kwargs
)
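This is a data iterator with a large number of built-in augmentation operations. It can read data either from a .rec file or from raw image files.
Use the path_imgrec parameter to load a .rec file, and the path_imglist parameter (together with path_root) to load raw image files.
Specify the path_imgidx parameter to enable data partitioning for distributed training, or shuffling, when reading from a .rec file.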
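The reason an index file matters is that a .rec file is a sequential stream; an index of record keys lets the iterator visit records in an arbitrary order or restrict itself to a subset. The following library-free sketch illustrates the idea behind part_index and num_parts. The helper partition_indices is hypothetical, and the contiguous, equal-sized split is an assumption about the scheme, not a copy of MXNet's implementation.

```python
import random


def partition_indices(indices, part_index, num_parts):
    """Return the contiguous slice of `indices` owned by one worker."""
    if not 0 <= part_index < num_parts:
        raise ValueError("part_index must be in [0, num_parts)")
    per_part = len(indices) // num_parts  # leftover records are dropped here
    start = part_index * per_part
    return indices[start:start + per_part]


# Ten record keys, as they might be listed in an index file.
keys = list(range(10))

# Each of three workers sees a disjoint slice of the keys.
parts = [partition_indices(keys, i, 3) for i in range(3)]
print(parts)

# Shuffling reorders the keys a worker owns; the set of records is unchanged.
rng = random.Random(0)
worker0 = list(parts[0])
rng.shuffle(worker0)
print(sorted(worker0) == parts[0])
```

In a real distributed job, each worker would pass its own rank as part_index and the total worker count as num_parts.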
References:
http://mxnet.incubator.apache.org/versions/master/api/python/image/image.html#mxnet.image.ImageIter
https://blog.csdn.net/u014380165/article/details/74906061
A usage example:
import matplotlib.pyplot as plt
import mxnet as mx
import numpy as np

data_iter = mx.image.ImageIter(batch_size=4, data_shape=(3, 227, 227),
                               path_imgrec="./data/caltech.rec",
                               path_imgidx="./data/caltech.idx")
# data_iter is an instance of mxnet.image.ImageIter.
# reset() resets the iterator to the beginning of the data.
data_iter.reset()
# batch is an mxnet.io.DataBatch, which is what next() returns.
batch = data_iter.next()
# data is an NDArray holding the first batch; with batch_size=4 its shape is (4, 3, 227, 227).
data = batch.data[0]
# Display each image in the batch, converting from CHW to HWC layout for matplotlib.
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1, 2, 0)))
plt.show()
Use mx.image.CreateAugmenter() to perform image augmentation.
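Conceptually, an augmenter list (such as the one passed to ImageIter through aug_list) is a sequence of transformations applied to each image in order. The sketch below illustrates that chaining idea in plain Python on a nested-list "image". All names here (horizontal_flip, center_crop, apply_augmenters) are hypothetical stand-ins and not MXNet functions, so treat this as a mental model rather than the library's implementation.

```python
def horizontal_flip(image):
    """Mirror an image stored as a list of rows."""
    return [row[::-1] for row in image]


def center_crop(size):
    """Build an augmenter that keeps the central size x size window."""
    def apply(image):
        top = (len(image) - size) // 2
        left = (len(image[0]) - size) // 2
        return [row[left:left + size] for row in image[top:top + size]]
    return apply


def apply_augmenters(image, aug_list):
    """Run every augmenter in order, feeding each result to the next one."""
    for aug in aug_list:
        image = aug(image)
    return image


# A 4x4 single-channel "image" with distinct pixel values.
image = [[r * 4 + c for c in range(4)] for r in range(4)]
aug_list = [center_crop(2), horizontal_flip]
result = apply_augmenters(image, aug_list)
print(result)
```

The order of the list matters: cropping before flipping, as here, touches fewer pixels than flipping the full image first.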