InsightFace_Pytorch project: a walkthrough of the data loading code

I have recently been reading the InsightFace code, in particular the data-loading part. The stock data loading is very slow: the dataset holds 6,930,097 images across 181,475 classes, and on my GTX 1070 machine a full traversal of the directory tree took 1486.17 s (about 24 minutes). I therefore wanted to improve the loading procedure; with the change described below, the same step takes only 11.45 s.
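
For reference, the slow step can be timed with a minimal sketch like the one below (the dataset path is a placeholder, and I am assuming the scan is measured by simply constructing the dataset object, which is what triggers the full directory walk):

import time

from torchvision import datasets

start = time.time()
# Constructing an ImageFolder walks the entire directory tree up front
# to build the (path, class_index) list -- this is the slow step.
ds = datasets.ImageFolder('/path/to/dataset')  # placeholder path
print('indexed {} images in {:.2f}s'.format(len(ds), time.time() - start))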

A walkthrough of the torchvision source (the stock data-loading code):

import torch.utils.data as data

from PIL import Image

import os
import os.path

def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (iterable of strings): allowed extensions, in lowercase

    Returns:
        bool: True if the filename ends with one of the given extensions
    """
    filename_lower = filename.lower()  # lower-case so the extension check is case-insensitive
    return any(filename_lower.endswith(ext) for ext in extensions)

# Return the class names and a mapping from class name to index
def find_classes(dir):
    # Collect all class sub-folders; this scan is what we later cache in a txt file
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    # Sort the class names so the index assignment is deterministic
    classes.sort()
    # Assign each class an integer index
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx

def make_dataset(dir, class_to_idx, extensions):
    images = []
    dir = os.path.expanduser(dir)  # expand "~" and "~user" to the user's home directory

    # Walk every class folder and collect all image files
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue

        for root, _, fnames in sorted(os.walk(d)):  # os.walk returns a generator; iterating it visits every subdirectory
            for fname in sorted(fnames):
                if has_file_allowed_extension(fname, extensions):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)

    return images  # a list of tuples, each of the form ("absolute image path", class_index)


class DatasetFolder(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (list[string]): A list of allowed extensions.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g., ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.

     Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
    """

    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        classes, class_to_idx = find_classes(root)
        samples = make_dataset(root, class_to_idx, extensions)
        if len(samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                               "Supported extensions are: " + ",".join(extensions))

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.samples)


    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target


    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
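
For context, here is a minimal sketch of how DatasetFolder might be instantiated; the pil_loader helper and the dataset path below are illustrative assumptions, not part of the original source (DatasetFolder and IMG_EXTENSIONS come from the listing above):

from PIL import Image
from torchvision import transforms

def pil_loader(path):
    # Open in binary mode and convert to RGB so the file handle is
    # closed before the image is used.
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

dataset = DatasetFolder(root='/path/to/dataset',  # placeholder path
                        loader=pil_loader,
                        extensions=IMG_EXTENSIONS,
                        transform=transforms.ToTensor())
sample, target = dataset[0]  # loads and transforms one image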

The change is to save the results of these two lines:

classes, class_to_idx = find_classes(root)
samples = make_dataset(root, class_to_idx, extensions)

to txt files once, and then read the data back from the txt files for training. The script that writes the txt files:

import os
import argparse

parser = argparse.ArgumentParser(description="Generate txt files that cache the dataset file list.")
parser.add_argument("-d", '--dataroot', type=str,
                    help="(REQUIRED) Absolute path to the dataset folder whose image paths\
                     and class indices will be written to txt files.",
                    default='/home/XXXXX/sdb/Caffe_Project/face_recognition/datasets/msair6'
                    )
parser.add_argument("-t", '--txt_path', type=str,
                    help="(REQUIRED) Absolute path of the folder where the txt files will be written.",
                    default='/home/XXXXX/sdb/Caffe_Project/face_recognition/txt/msair6_txt'
                    )
args = parser.parse_args()
dataroot = args.dataroot
txt_path = args.txt_path

def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (iterable of strings): allowed extensions, in lowercase

    Returns:
        bool: True if the filename ends with one of the given extensions
    """
    filename_lower = filename.lower()  # lower-case so the extension check is case-insensitive
    return any(filename_lower.endswith(ext) for ext in extensions)

# Return the class names and a mapping from class name to index
def find_classes(dir):
    # Collect all class sub-folders
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    # Sort the class names so the index assignment is deterministic
    classes.sort()
    # Assign each class an integer index
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx

def make_dataset(dir, class_to_idx, extensions):
    images = []
    dir = os.path.expanduser(dir)  # expand "~" and "~user" to the user's home directory

    # Walk every class folder and collect all image files
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue

        for root, _, fnames in sorted(os.walk(d)):  # os.walk returns a generator; iterating it visits every subdirectory
            for fname in sorted(fnames):
                if has_file_allowed_extension(fname, extensions):
                    path = os.path.join(root, fname)
                    item = [path, class_to_idx[target]]
                    images.append(item)

    return images  # a list of [absolute image path, class_index] pairs

def save(filename, docs):
    # Write a dict as one "key,value" line per entry
    with open(filename, 'w') as fh:
        for key, value in docs.items():
            fh.write(key + "," + str(value))
            fh.write('\n')

def save_list(filename, docs):
    # Write a list of [path, class_index] pairs as "path,class_index" lines
    with open(filename, 'w') as fh:
        for doc in docs:
            fh.write(doc[0] + "," + str(doc[1]))
            fh.write('\n')

def generate_txt_file(dataroot=None, txt_path=None):
    extensions = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
    classes, class_to_idx = find_classes(dataroot)
    samples = make_dataset(dataroot, class_to_idx, extensions)


    # classes_txt = os.path.join(txt_path, "classes.txt")
    class_to_idx_txt = os.path.join(txt_path, "class_to_idx.txt")
    samples_txt = os.path.join(txt_path, "samples.txt")

    # save(classes_txt, classes)
    save(class_to_idx_txt, class_to_idx)

    save_list(samples_txt, samples)


if __name__ == '__main__':
    generate_txt_file(dataroot=dataroot, txt_path=txt_path)
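
The two generated files are plain comma-separated text. Hypothetically (the class names and paths below are made-up placeholders), class_to_idx.txt and samples.txt would look like:

# class_to_idx.txt: one "class_name,index" per line
id_0000001,0
id_0000002,1

# samples.txt: one "absolute_path,class_index" per line
/path/to/msair6/id_0000001/001.jpg,0
/path/to/msair6/id_0000001/002.jpg,0
/path/to/msair6/id_0000002/001.jpg,1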

Then I commented out the following two lines in the torchvision source:

classes, class_to_idx = find_classes(root)
samples = make_dataset(root, class_to_idx, extensions)

and added the following call in their place:

# read the cached txt files instead of walking the directory tree
classes, class_to_idx, samples = read_txt()

The source of that function looks like this:

def read_txt(root='/home/fuxueping/sdb/Caffe_Project/face_recognition/txt/msair6_txt'):
# def read_txt(root='/home/fuxueping/sdb/Caffe_Project/face_recognition/txt/imgs_txt'):
    classes = []
    class_to_idx = dict()
    samples = []

    txt_class_to_idx = os.path.join(root, "class_to_idx.txt")
    txt_samples = os.path.join(root, "samples.txt")

    try:
        f_class_to_idx = open(txt_class_to_idx, "r")
    except IOError:
        print("Error: class_to_idx.txt not found or could not be opened")
    else:
        print("class_to_idx.txt read successfully")

        for line in f_class_to_idx.readlines():
            str_line = line.strip('\n').split(',')
            class_to_idx[str_line[0]] = int(str_line[1])
            classes.append(str_line[0])
        f_class_to_idx.close()

    try:
        f_samples = open(txt_samples, "r")
    except IOError:
        print("Error: samples.txt not found or could not be opened")
    else:
        print("samples.txt read successfully")

        for line in f_samples.readlines():
            str_line = line.strip('\n').split(',')
            item = (str_line[0], int(str_line[1]))
            samples.append(item)
        f_samples.close()

    return classes, class_to_idx, samples
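
Putting the pieces together, the patched DatasetFolder.__init__ then looks roughly like the sketch below; everything except the read_txt() call is unchanged from the torchvision source shown earlier:

    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        # classes, class_to_idx = find_classes(root)              # disabled
        # samples = make_dataset(root, class_to_idx, extensions)  # disabled
        classes, class_to_idx, samples = read_txt()  # load the cached lists instead
        if len(samples) == 0:
            raise RuntimeError("Found 0 entries in the cached samples.txt")

        self.root = root
        self.loader = loader
        self.extensions = extensions
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.transform = transform
        self.target_transform = target_transform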

This way the txt files only need to be generated once and can then be reused. When training is interrupted and restarted many times, or the same data is used for several training runs, this saves a great deal of time.

A good reference link: https://www.jianshu.com/p/220357ca3342
