I have recently been reading the InsightFace code, in particular the data-loading part. The stock loading is slow: the dataset contains 6,930,097 images across 181,475 classes, and a full directory traversal on my GTX 1070 machine took 1486.17 s (about 25 minutes). I therefore reworked the loading; after the change, building the index takes only 11.45 s.
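For reference, a timing like the one above can be reproduced by simply timing the construction of an ImageFolder dataset, since construction triggers the full directory walk (the dataset path below is a placeholder):

import time
from torchvision.datasets import ImageFolder

t0 = time.time()
# Building the dataset runs find_classes + make_dataset over the whole tree.
dataset = ImageFolder('/path/to/dataset')
print('index built in %.2f s, %d images, %d classes'
      % (time.time() - t0, len(dataset), len(dataset.classes)))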
The relevant torchvision source (the original loading code):
import torch.utils.data as data
from PIL import Image
import os
import os.path

def has_file_allowed_extension(filename, extensions):
    """Checks if a file has an allowed extension.

    Args:
        filename (string): path to a file
        extensions (list[string]): allowed extensions

    Returns:
        bool: True if the filename ends with one of the given extensions
    """
    filename_lower = filename.lower()  # lower-case the name so the check is case-insensitive
    return any(filename_lower.endswith(ext) for ext in extensions)

# Returns the class names and the mapping from class name to index.
def find_classes(dir):
    # Collect all class folders; this is the part that can be precomputed
    # once and saved to a txt file.
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    # Sort the classes so the indexing is deterministic.
    classes.sort()
    # Assign each class an integer index.
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx

def make_dataset(dir, class_to_idx, extensions):
    images = []
    dir = os.path.expanduser(dir)  # expand "~" and "~user" to the home directory
    # Walk every class folder and collect all image files.
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        # os.walk returns a generator, so it has to be iterated to visit everything.
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if has_file_allowed_extension(fname, extensions):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)
    # Returns a list of tuples, each of the form ("absolute image path", class index).
    return images

class DatasetFolder(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (list[string]): A list of allowed extensions.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g., ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
    """
    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        classes, class_to_idx = find_classes(root)
        samples = make_dataset(root, class_to_idx, extensions)
        if len(samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                               "Supported extensions are: " + ",".join(extensions))

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
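For context, torchvision's ImageFolder is a thin subclass of DatasetFolder that plugs in a PIL-based default loader and the IMG_EXTENSIONS list above, so the full directory walk runs every time a dataset like the following is constructed (paths are placeholders):

import torch
from torchvision import datasets, transforms

dataset = datasets.ImageFolder('/path/to/dataset',
                               transform=transforms.ToTensor())
loader = torch.utils.data.DataLoader(dataset, batch_size=64,
                                     shuffle=True, num_workers=4)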
The change is to save the results of these two calls:
classes, class_to_idx = find_classes(root)
samples = make_dataset(root, class_to_idx, extensions)
to txt files, and then read the txt files back at training time. The script that writes the txt files:
import os
import argparse

parser = argparse.ArgumentParser(description="Generate txt files caching the dataset index.")
parser.add_argument('-d', '--dataroot', type=str,
                    help="(REQUIRED) Absolute path to the dataset folder whose image paths \
                    and class indices will be written to the txt files.",
                    default='/home/XXXXX/sdb/Caffe_Project/face_recognition/datasets/msair6'
                    )
parser.add_argument('-t', '--txt_path', type=str,
                    help="Absolute path of the folder where the txt files will be written.",
                    default='/home/XXXXX/sdb/Caffe_Project/face_recognition/txt/msair6_txt'
                    )
args = parser.parse_args()

dataroot = args.dataroot
txt_path = args.txt_path

def has_file_allowed_extension(filename, extensions):
    """Checks if a file has an allowed extension.

    Args:
        filename (string): path to a file
        extensions (list[string]): allowed extensions

    Returns:
        bool: True if the filename ends with one of the given extensions
    """
    filename_lower = filename.lower()  # lower-case the name so the check is case-insensitive
    return any(filename_lower.endswith(ext) for ext in extensions)


# Returns the class names and the mapping from class name to index.
def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    # Sort the classes so the indexing is deterministic.
    classes.sort()
    # Assign each class an integer index.
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx

def make_dataset(dir, class_to_idx, extensions):
    images = []
    dir = os.path.expanduser(dir)  # expand "~" and "~user" to the home directory
    # Walk every class folder and collect all image files.
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        # os.walk returns a generator, so it has to be iterated to visit everything.
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if has_file_allowed_extension(fname, extensions):
                    path = os.path.join(root, fname)
                    item = [path, class_to_idx[target]]
                    images.append(item)
    # Returns a list of [absolute image path, class index] pairs.
    return images

def save(filename, docs):
    # Write a dict as one "key,value" record per line.
    with open(filename, 'w') as fh:
        for key, value in docs.items():
            fh.write(key + "," + str(value))
            fh.write('\n')


def save_list(filename, docs):
    # Write a list of (path, class index) pairs as one "path,index" record per line.
    with open(filename, 'w') as fh:
        for doc in docs:
            fh.write(doc[0] + "," + str(doc[1]))
            fh.write('\n')

def generate_txt_file(dataroot=None, txt_path=None):
    extensions = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
    classes, class_to_idx = find_classes(dataroot)
    samples = make_dataset(dataroot, class_to_idx, extensions)

    # classes_txt = os.path.join(txt_path, "classes.txt")
    class_to_idx_txt = os.path.join(txt_path, "class_to_idx.txt")
    samples_txt = os.path.join(txt_path, "samples.txt")

    # save(classes_txt, classes)
    save(class_to_idx_txt, class_to_idx)
    save_list(samples_txt, samples)


if __name__ == '__main__':
    generate_txt_file(dataroot=dataroot, txt_path=txt_path)
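Each run of the script writes two plain-text files, one comma-separated record per line. With hypothetical class folders and file names they look roughly like this:

class_to_idx.txt:
person_0001,0
person_0002,1

samples.txt:
/datasets/msair6/person_0001/001.jpg,0
/datasets/msair6/person_0001/002.jpg,0
/datasets/msair6/person_0002/001.jpg,1

Note that this format relies on paths and class names not containing commas, since read_txt() below splits each line on ','.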
Then I commented out these two lines in the torchvision source:
classes, class_to_idx = find_classes(root)
samples = make_dataset(root, class_to_idx, extensions)
and replaced them with a call that reads the txt files instead:
# read the txt files
classes, class_to_idx, samples = read_txt()
The function itself looks like this:
def read_txt(root='/home/fuxueping/sdb/Caffe_Project/face_recognition/txt/msair6_txt'):
    # alternative cache location:
    # root='/home/fuxueping/sdb/Caffe_Project/face_recognition/txt/imgs_txt'
    classes = []
    class_to_idx = dict()
    samples = []
    txt_class_to_idx = os.path.join(root, "class_to_idx.txt")
    txt_samples = os.path.join(root, "samples.txt")

    try:
        f_class_to_idx = open(txt_class_to_idx, "r")
    except IOError:
        print("Error: class_to_idx.txt not found or could not be opened")
    else:
        print("class_to_idx.txt read successfully")
        for line in f_class_to_idx.readlines():
            str_line = line.strip('\n').split(',')
            class_to_idx[str_line[0]] = int(str_line[1])
            classes.append(str_line[0])
        f_class_to_idx.close()

    try:
        f_samples = open(txt_samples, "r")
    except IOError:
        print("Error: samples.txt not found or could not be opened")
    else:
        print("samples.txt read successfully")
        for line in f_samples.readlines():
            str_line = line.strip('\n').split(',')
            sample = (str_line[0], int(str_line[1]))
            samples.append(sample)
        f_samples.close()

    return classes, class_to_idx, samples
This way the txt files only have to be generated once and can be reused many times afterwards; when training is interrupted and restarted, or when the same dataset feeds multiple training runs, this saves a great deal of time.
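As an alternative to patching the installed torchvision source directly, the same idea can be wrapped in a small subclass that replaces the index-building step. A minimal sketch under the assumptions above (the class name CachedImageFolder is my own; read_txt is the function defined earlier; default_loader and IMG_EXTENSIONS are the standard ones from torchvision):

from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS, default_loader

class CachedImageFolder(DatasetFolder):
    """An ImageFolder-like dataset that loads its (path, class_index) list from
    the txt files produced by generate_txt_file(), instead of walking the tree."""

    def __init__(self, root, txt_root, transform=None, target_transform=None):
        # Deliberately skip DatasetFolder.__init__ (which would walk the whole
        # directory tree) and fill in the same attributes from the cached txt files.
        self.root = root
        self.loader = default_loader
        self.extensions = IMG_EXTENSIONS
        self.classes, self.class_to_idx, self.samples = read_txt(txt_root)
        self.transform = transform
        self.target_transform = target_transform

Since DatasetFolder's __len__ and __getitem__ only touch self.samples, self.loader, and the two transforms, this keeps the stock torchvision installation untouched while giving the same speedup.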