MxNet框架下的数据读取
在做深度学习训练时,数据读取是必不可少的一步。下面介绍 MXNet 中两种常用的数据读取方法:
一、直接读取文件夹下的图片数据,不转换成 .rec 和 .idx 格式
这种方式常用于图像分类等任务:train 文件夹下每个类别对应一个子文件夹,子文件夹中存放该类别的图片;test 文件夹的组织方式与 train 相同。
主要使用函数:
gluon.data.vision.ImageFolderDataset
from mxnet import gluon
from mxnet import nd
from mxnet import image
import sys
sys.path.append('..')
import mxnet as mx
import time
from mxnet.gluon import data as gdata, loss as gloss, nn, utils as gutils
from mxnet import autograd as autograd
from mxnet import nd
from mxnet import ndarray as nd
# Root directory of the dataset; the loading code below reads its `train`
# and `test` sub-directories.
data_dir = '/home/zn/project/mxnet_project/data'

# Augmenters applied to every training image: a horizontal flip with
# probability 0.5. Resizing is not listed here because `transform` already
# resizes each image itself.
train_augs = [image.HorizontalFlipAug(0.5)]

# Test images get no augmentation.
test_augs = []
def transform(data, label, augs):
    """Prepare one (image, label) sample for training or evaluation.

    The image is resized to 224x224, cast to float32, passed through every
    augmenter in ``augs`` in order, and finally transposed with axes
    ``(2, 0, 1)`` (channel-last to channel-first, assuming an H x W x C
    input image).

    Args:
        data: Image array accepted by ``image.imresize``.
        label: Class label of the image.
        augs: Iterable of callables, each mapping an image to an image.

    Returns:
        A ``(image, label)`` tuple where the label is a float32 scalar.
    """
    resized = image.imresize(data, 224, 224).astype('float32')
    for augment in augs:
        resized = augment(resized)
    channel_first = nd.transpose(resized, (2, 0, 1))
    scalar_label = nd.array([label]).asscalar()
    return channel_first, scalar_label.astype('float32')
#imshow
import matplotlib.pyplot as plt
def show_images(imgs, nrows, ncols, figsize=None):
    """Display images in an ``nrows`` x ``ncols`` grid without axis ticks.

    Images are placed row by row: cell ``(i, j)`` shows ``imgs[i * ncols + j]``.
    If there are fewer images than grid cells, the remaining cells are left
    blank instead of raising ``IndexError``.

    Args:
        imgs: Indexable collection supporting ``len()`` whose items expose
            ``.asnumpy()`` and are in a layout ``imshow`` accepts.
        nrows: Number of grid rows.
        ncols: Number of grid columns.
        figsize: Optional ``(width, height)`` of the figure; defaults to
            ``(ncols, nrows)``.
    """
    if not figsize:
        figsize = (ncols, nrows)
    # squeeze=False keeps `figs` two-dimensional even when nrows or ncols is 1;
    # with the default squeezing, `figs[i][j]` fails for single-row,
    # single-column and 1x1 grids.
    _, figs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    total = len(imgs)
    for i in range(nrows):
        for j in range(ncols):
            cell = figs[i][j]
            index = i * ncols + j
            if index < total:
                cell.imshow(imgs[index].asnumpy())
            cell.axes.get_xaxis().set_visible(False)
            cell.axes.get_yaxis().set_visible(False)
    plt.show()
# Key part of the data loading: build the datasets directly from the image
# folders under `data_dir`.
def _train_transform(img, label):
    """Apply `transform` with the training augmenters."""
    return transform(img, label, train_augs)


def _test_transform(img, label):
    """Apply `transform` with the test augmenters."""
    return transform(img, label, test_augs)


train_imgs = gluon.data.vision.ImageFolderDataset(
    data_dir + '/train', transform=_train_transform)
test_imgs = gluon.data.vision.ImageFolderDataset(
    data_dir + '/test', transform=_test_transform)

# Shuffled loader over the training set, 32 samples per batch.
data = gluon.data.DataLoader(train_imgs, 32, shuffle=True)

# Preview only the first batch, then stop iterating.
for X, _ in data:
    # Undo the channel-first layout produced by `transform` so that
    # `show_images` can hand the images to imshow.
    X = X.transpose((0, 2, 3, 1))
    # Clamp to [0, 255] and rescale to [0, 1].
    X = X.clip(0, 255) / 255
    print(X.shape, X[1, 100, 100, 1])
    # 4 x 8 grid = 32 cells, matching the batch size above.
    show_images(X, 4, 8)
    break
二、先将数据转换成 .rec 和 .idx 格式,再进行读取
转换使用的脚本为:
/mxnet/tools/im2rec.py
from mxnet import image
from mxnet import nd
# Side length used for both image dimensions; `get_iterators` expands it to
# the shape (3, data_shape, data_shape).
data_shape = 256

# Number of samples per batch produced by the iterators.
batch_size = 32

# Three-element array, presumably the per-channel RGB mean; it is not
# referenced anywhere in the visible code. TODO confirm intended use.
rgb_mean = nd.array([123, 117, 104])
def get_iterators(data_shape, batch_size):
    """Create the training and validation detection iterators.

    The record files ``train.rec``, ``train.idx`` and ``val.rec`` are looked
    up inside the module-level ``data_dir``.

    Args:
        data_shape: Side length of the square output images; the iterators
            are created with shape ``(3, data_shape, data_shape)``.
        batch_size: Number of samples per batch.

    Returns:
        A tuple ``(train_iter, val_iter, class_names, num_class)``.
    """
    # Imported locally: the validation path already relied on `os.path.join`,
    # but no top-level `import os` is visible in this file.
    import os

    class_names = ['pikachu']
    num_class = len(class_names)
    image_shape = (3, data_shape, data_shape)  # shape of the output images

    train_iter = image.ImageDetIter(
        batch_size=batch_size,
        data_shape=image_shape,
        # os.path.join supplies the separator that plain string concatenation
        # (`data_dir + 'train.rec'`) silently dropped when `data_dir` has no
        # trailing slash.
        path_imgrec=os.path.join(data_dir, 'train.rec'),
        path_imgidx=os.path.join(data_dir, 'train.idx'),
        shuffle=True,  # read the dataset in random order
        # NOTE(review): mean=True presumably selects MXNet's built-in default
        # mean values - confirm against the ImageDetIter documentation.
        mean=True,
        rand_crop=1,  # probability of random cropping is 1
        min_object_covered=0.95,
        max_attempts=200,
    )
    val_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, 'val.rec'),
        batch_size=batch_size,
        data_shape=image_shape,
        shuffle=False,
        mean=True,
    )
    return train_iter, val_iter, class_names, num_class
# Build the iterators once at module level; the validation iterator is
# exposed under the name `test_data`.
train_data, test_data, class_names, num_class = get_iterators(
    data_shape=data_shape,
    batch_size=batch_size,
)