前言
InsightFace 的 Dataset Zoo 中提供的人脸识别数据集都是 MXNet 的 rec 格式。使用 PyTorch 训练时,直接读取 rec 格式的人脸数据集可以大大加快数据读取速度。(这个方法是在群里看到的,我已经用在自己的 PyTorch 训练代码里,亲测效果挺好。)
代码如下(示例):
import os
import numbers
import numpy
import cv2
import torch
import mxnet as mx
from torch.utils.data import Dataset
class MXFaceDataset(Dataset):
    """PyTorch ``Dataset`` that reads face images directly from MXNet record files.

    The dataset opens ``train.rec`` / ``train.idx`` found under ``root_dir``
    and decodes one record per ``__getitem__`` call.

    Args:
        root_dir: Directory that contains ``train.rec`` and ``train.idx``.
        transform: Optional callable applied to the decoded image array.
        train: Stored as ``self.train``; this class does not otherwise use it.
    """

    def __init__(self, root_dir, transform=None, train=True):
        super(MXFaceDataset, self).__init__()
        self.transform = transform
        self.train = train
        self.root_dir = root_dir
        path_imgrec = os.path.join(root_dir, 'train.rec')
        path_imgidx = os.path.join(root_dir, 'train.idx')
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r")

        # Record 0 is inspected as a metadata record: when its flag is
        # positive, label[0] is the exclusive upper bound of the image record
        # indices and (label[0], label[1]) is kept as ``header0``.
        s = self.imgrec.read_idx(0)
        header, _ = mx.recordio.unpack(s)

        # Always define header0 so the summary below cannot hit an
        # AttributeError when record 0 carries no metadata (flag == 0).
        self.header0 = None
        if header.flag > 0:
            print("header0 label:", header.label)
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.imgidx = list(range(1, int(header.label[0])))
        else:
            # No metadata record: every key in the index file is a sample.
            self.imgidx = list(self.imgrec.keys)
        print(self._summary())

    def _summary(self):
        """Return the one-line sample/class count report printed by ``__init__``."""
        if self.header0 is None:
            num_classes = "unknown"
        else:
            num_classes = int(self.header0[1] - self.header0[0])
        return "Number of Samples:{} Number of Classes: {}".format(
            len(self.imgidx), num_classes)

    def __getitem__(self, index):
        """Decode and return the ``index``-th sample.

        Returns:
            A ``(sample, label)`` tuple. ``sample`` is the decoded image as a
            NumPy array, or whatever ``transform`` returns for it. ``label``
            is a scalar ``torch.long`` tensor.
        """
        idx = self.imgidx[index]
        s = self.imgrec.read_idx(idx)
        header, img = mx.recordio.unpack(s)
        label = header.label
        # The header label may be a scalar or a sequence; for a sequence the
        # first element is used as the label.
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = torch.tensor(label, dtype=torch.long)
        sample = mx.image.imdecode(img).asnumpy()  # RGB
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.imgidx)
if __name__ == '__main__':
    # Smoke test: build the dataset without a transform and print the shape
    # and label of every batch produced by a shuffled, multi-worker loader.
    root_dir = '/datasets/train'
    trainset = MXFaceDataset(root_dir)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True)
    # The batch index is not needed, so iterate the loader directly.
    for sample, label in train_loader:
        print(sample.shape, label)