导读
当我们需要使用大数据集来训练模型的时候,这时候加载和读取数据就成为了我们训练的瓶颈。Mxnet提供了一种数据的存储格式ImageRecord
,通过这种数据的存储格式可以使用多进程来加载数据,可以极大的提高训练的效率。
Mxnet提供了一个mx.recordio
模块,可以对ImageRecord
文件进行操作,该模块里面包含了两个子模块。MXRecordIO
支持顺序读写,MXIndexedRecordIO
支持随机读写。下面就来我们一起来看一下,这两个子模块中的一些接口函数,以及如何使用它们来读取和保存ImageRecord 文件
MXRecordIO
- 将数据保存为rec文件
在python3中需要先将字符串转为byte
数据之后才能写入到rec文件中,而python2中不需要做这一步处理
import mxnet as mx
write_record = mx.recordio.MXRecordIO("test.rec","w")
for i in range(5):
#向rec文件中写入数据
#将字符串数据转换为bytes数据,写入到文件中
write_record.write(str.encode("record_%d"%i))
write_record.close()
- 读取rec文件的内容
read_record = mx.recordio.MXRecordIO("test.rec","r")
while True:
item = read_record.read()
if not item:
break
#将bytes数据转换为字符串
print(item.decode())
read_record.close()
- 重新开始读取rec文件
rec文件的读取采用的是指针来读取的,当指针移到到文件的结束位置时,此时再读取文件的时候,会返回一个None
。如果你想要从头开始读取文件,可以通过reset
方法,将指针移动到文件的开始位置。
read_record = mx.recordio.MXRecordIO("test.rec","r")
i = 3
while True:
item = read_record.read()
if not item:
break
#将bytes数据转换为字符串
print(item.decode())
#再次读取内容
print(read_record.read())#None
#将指针移到文件的开始位置
read_record.reset()
#读取内容
print(read_record.read())#b'record_0'
read_record.close()
MXIndexedRecordIO
在训练模型的时候我们希望打乱数据集中数据的顺序,通过MXIndexedRecordIO
可以获取到一个随机顺序的数据
- 保存数据
import mxnet as mx
write_record = mx.recordio.MXIndexedRecordIO('test.idx', 'test.rec', 'w')
for i in range(5):
write_record.write_idx(i, str.encode('record_%d'%i))
write_record.close()
- 读取数据
read_record = mx.recordio.MXIndexedRecordIO('test.idx', 'test.rec', 'r')
for i in range(5):
item = read_record.read()
if item is None:
break
print(item.decode())
read_record.close()
- 获取随机顺序的数据
通过read_idx
我们可以移动指针的位置,来获取随机的数据
import random
idx_list = [i for i in range(0,4)]
#打乱idx的顺序
random.shuffle(idx_list)
read_record = mx.recordio.MXIndexedRecordIO('test.idx', 'test.rec', 'r')
for i in idx_list:
#获取指定idx的数据
read_record.read_idx(i)
item = read_record.read()
if item is None:
break
print(item.decode())
read_record.close()
- 获取文件中的idx
read_record = mx.recordio.MXIndexedRecordIO('test.idx', 'test.rec', 'r')
#通过items来获取idx
for key,_ in read_record.idx.items():
#key就是文件中的idx
print(key)
#获取文件的idx
print(read_record.keys)
二进制数据的保存和读取
在训练模型的时候,我们的数据格式大多是(data,label)
的形式来进行保存的,而其中的data可能是,可能是图片数据或者其他类型的数据,我们都可以将其转换为二进制数据来进行保存。
MXNet的每个rec
文件能够用来保存任意的二进制数据,而且mx.recordio
还提供了几个非常有用的函数用来打包
和解压
数据,pack
,unpack
, pack_img
, and unpack_img
字符串数据的保存和读取
#需要保存的字符串数据
data = "data"
#label可以是列表类型的数据
label1 = 1.0#[1.0,2.0,3.0]
#数据的id
id = 1
#打包数据的标签和id
header1 = mx.recordio.IRHeader(flag=0,label=label1,id=id,id2=0)
#将字符串转为二进制数据
s1 = mx.recordio.pack(header1,str.encode(data))
#解压数据
unpack_header1,unpack_s1 = mx.recordio.unpack(s1)
print(unpack_header1)#HEADER(flag=0, label=1.0, id=1, id2=0)
#获取数据的标签和id信息
print(unpack_header1.label,unpack_header1.id)#1.0 1
#获取压缩的数据
print(unpack_s1.decode())#data
图片数据的保存和读取
- 保存rec文件
import mxnet as mx
write_record = mx.recordio.MXIndexedRecordIO("test.idx","test.rec", 'w')
#读取图片
img_path = "img/test.jpg"
#将图片数据转为Numpy数组
img = mx.image.imread(img_path).asnumpy()
label = 1.0
header = mx.recordio.IRHeader(flag=0,label=label,id=0,id2=0)
s = mx.recordio.pack_img(header,img,quality=95,img_fmt=".jpg")
#将数据写入到rec文件中
write_record.write_idx(0,s)
write_record.close()
- 读取rec文件
read_record = mx.recordio.MXIndexedRecordIO("test.idx","test.rec","r")
#遍历rec文件
for idx in read_record.keys:
item = read_record.read()
#解压数据
header,s = mx.recordio.unpack(item)
#将图片的bytes数据转换为ndarray
img = mx.image.imdecode(s)
print(img.shape)
通过ImageRecordIter读取rec文件
from mxnet.io import ImageRecordIter
train_data = ImageRecordIter(
path_imgrec = "test.rec",
path_imgidx = "test.idx",
data_shape = (3, 112, 112),
batch_size = 1,
shuffle = True
)
for batch in train_data:
print(batch.data[0].shape, batch.label[0].shape)
break
参考: