mxnet 数据部分接口解读和可视化

(一)数据读取

mxnet 参考api 文档:http://mxnet.incubator.apache.org/api/python/io/io.html

这里主要用到两个函数接口:io.ImageRecordIter以及recordio.MXIndexedRecordIO。

  • io.ImageRecordIter:

函数解释:从rec文件中读取数据,根据batchSize的大小读取数据。

函数参数:(几个重要的参数,其余的参考文档)

                path_imgrec :rec文件地址

                data_shape:图像输出大小,形式为(channel,height,width)。如果与原图不同,会进行裁剪。

                label_width:一个图像对应的标签数,多属性时会用到,默认为1。

                batch_size:一次性读取的数据大小。

                shuffle:是否在读取时进行混洗,随机读取。 

注意点:

           1.当shuffle设置为false时,读取的顺序是按照lst文件的顺序进行读取的。在进行训练的时候,最好是设置为true,对于训练有提升效果。

            2.从这个迭代器中取数据,是调用next()函数。

  • recordio.MXIndexedRecordIO

函数解释:与上个相比,该函数接口支持数据的随机读取。

函数参数:idx文件路径和rec文件路径。

注意点:通过tools/im2rec.py文件生成数据的时候,一般会有三个文件,后缀为.lst,.idx,.rec。一般用到的是lst文件和rec文件。idx文件一般被人忽略,其实idx文件的内容是决定了该接口可以随机读取的关键。

idx文件内容为两列,一列为文件的index,对应lst文件的第一列。idx文件第二列是编码位置,也就是该文件在rec中的存储位置段。

从该接口读取数据的时候,是按照索引读取的,索引数就是原始图像在原始文件夹的位置,即lst未打乱前的从index=0,label=0开始索引,按index索引。当然,有了索引,那就可以随机读取图片。

(二)读取代码示例

    path_idx='train_webface_train.idx'
    path_rec='train_webface_train.rec'
    path_lst='train_webface_train.lst'
    ## 1st
    train = mx.io.ImageRecordIter(path_imgrec=path_rec,batch_size=1,data_shape = (1,144,128))
    train.reset()
    #data
    data = next(train)
    image=data.data[0][0][0].asnumpy()
    
    ## 2nd
    imrec = mx.recordio.MXIndexedRecordIO(path_idx, path_rec, 'r')
    #data
    index = 0
    s = imrec.read_idx(index)
    header, img = mx.recordio.unpack(s)
    image1 = mx.image.imdecode(img, 0)
    image1 = mx.nd.transpose(image1, axes=(2, 0, 1))[0].asnumpy()


(三)数据可视化

在上面的代码中,注意从data.data是个list。list内的ndarray是4维的数据,所以需要多层索引。

而从imrec中读取的数据需要先解码,之后解码的图像数据的维度是(h,w,channel),所以需要先转换维度。

最后将ndarray转换为numpy,用matplot画图。

train=mx.io.ImageRecordIter(path_imgrec=path_rec,batch_size=1,data_shape=(1,height,width))
data=next(train)

## 单通道画图
image=data.data[0][0][0].asnumpy()
import matplotlib.pyplot as plt
plt.imshow(image,cmap='gray')
plt.title('image')
plt.show()

## 3通道画图
image=data.data[0][0].asnumpy().astype(np.uint8).transpose((1,2,0))
import matplotlib.pyplot as plt
plt.imshow(image)
plt.title('image')
plt.show()




               

 

 

 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值