PyTorch 读取 MXNet 的 rec 格式数据

前言

InsightFace 的 Dataset Zoo 中,人脸识别数据集都是以 MXNet 的 rec 格式提供的。使用 PyTorch 训练时,直接读取 rec 格式的人脸数据集可以大大加快数据读取速度。(这个方法是在交流群里看到的,已经用在我自己的 PyTorch 训练代码里,亲测效果挺好。)

代码如下(示例):

import os
import numbers
import numpy
import cv2
import torch
import mxnet as mx
from torch.utils.data import Dataset


class MXFaceDataset(Dataset):
    """PyTorch ``Dataset`` backed by an MXNet indexed RecordIO file pair.

    The dataset opens ``train.rec`` / ``train.idx`` inside ``root_dir`` and
    serves one decoded image plus its integer label per index.

    Record 0 is read as a meta record. When its header ``flag`` is positive,
    ``header.label[0]`` and ``header.label[1]`` are stored as ``header0`` and
    the image records are taken to be indices ``1 .. label[0] - 1``.
    Otherwise every key found in the index file is used and ``header0`` is
    ``None`` (the class count cannot be derived in that case).

    Args:
        root_dir: Directory containing ``train.rec`` and ``train.idx``.
        transform: Optional callable applied to the decoded image array
            before it is returned.
        train: Stored on the instance as ``self.train``; not otherwise used
            by this class.
    """

    def __init__(self, root_dir, transform=None, train=True):
        super(MXFaceDataset, self).__init__()
        self.transform = transform
        self.train = train
        self.root_dir = root_dir
        path_imgrec = os.path.join(root_dir, 'train.rec')
        path_imgidx = os.path.join(root_dir, 'train.idx')
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r")
        s = self.imgrec.read_idx(0)
        header, _ = mx.recordio.unpack(s)
        if header.flag > 0:
            print("header0 label:", header.label)
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.imgidx = list(range(1, int(header.label[0])))
        else:
            # No meta record: fall back to every key in the index file.
            # header0 must still exist, otherwise the summary below would
            # raise AttributeError.
            self.header0 = None
            self.imgidx = list(self.imgrec.keys)
        print(self._format_summary(len(self.imgidx), self.header0))

    @staticmethod
    def _format_summary(num_samples, header0):
        """Build the one-line sample/class summary printed at construction.

        ``header0`` is the ``(label[0], label[1])`` pair from the meta record,
        or ``None`` when no meta record was found; the class count is then
        reported as ``unknown``.
        """
        if header0 is None:
            num_classes = "unknown"
        else:
            num_classes = int(header0[1] - header0[0])
        return "Number of Samples:{} Number of Classes: {}".format(num_samples, num_classes)

    def __getitem__(self, index):
        """Return ``(sample, label)`` for the ``index``-th image record.

        ``sample`` is the decoded image as a NumPy array (after ``transform``
        if one was given); ``label`` is a ``torch.long`` scalar tensor.
        """
        idx = self.imgidx[index]
        s = self.imgrec.read_idx(idx)
        header, img = mx.recordio.unpack(s)
        label = header.label
        # The header label may be a scalar or an array; use its first entry
        # when it is not a plain number.
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = torch.tensor(label, dtype=torch.long)

        sample = mx.image.imdecode(img).asnumpy()  # RGB
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        """Return the number of image records served by this dataset."""
        return len(self.imgidx)


if __name__ == '__main__':
    # Smoke test: open the dataset under a fixed directory and print the
    # shape and label of every sample, one sample per batch.
    dataset = MXFaceDataset('/datasets/train')
    total_samples = len(dataset)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    for sample, label in loader:
        print(sample.shape, label)
  • 5
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值