Method for Data Imbalance
对于分布不均的数据集使用,从而避免long tail distribution。例如CASIA-WebFace。
import os
import random
def list_of_groups(init_list, children_list_len):
    """
    Split a sequence into consecutive chunks of a fixed size.

    Elements keep their original order. When ``len(init_list)`` is not a
    multiple of ``children_list_len`` the final chunk holds the leftover
    elements and is shorter than the others.

    :param init_list: (list) the sequence to split; it must support
        ``len()`` and slicing
    :param children_list_len: (int) number of elements per chunk (the chunk
        size, not the number of chunks)
    :return: (list) the chunks in order, each chunk itself a list; an empty
        list when ``init_list`` is empty
    :raises ValueError: if ``children_list_len`` is less than 1
    """
    if children_list_len <= 0:
        # A zero size used to surface as an opaque ZeroDivisionError and a
        # negative one produced meaningless output; fail loudly instead.
        raise ValueError(
            'children_list_len must be a positive integer, got {!r}'.format(
                children_list_len))

    # list(...) keeps every chunk a list even when init_list is another
    # sequence type such as a tuple.
    return [
        list(init_list[start:start + children_list_len])
        for start in range(0, len(init_list), children_list_len)
    ]
def dataset_split(dataset_path, batch_size, select_num):
    """
    Randomly pick ``select_num`` images from every sub-folder of a dataset.

    Every immediate sub-directory of ``dataset_path`` is visited in
    ``os.listdir`` order. A folder that holds at least ``select_num`` entries
    contributes one list of ``select_num`` randomly chosen file paths; a
    folder with fewer entries is reported on stdout and skipped. Entries of
    ``dataset_path`` that are not directories are ignored.

    ``batch_size`` does not change the shape of the result (one inner list
    per folder); it is only checked to be large enough to hold at least one
    folder's worth of samples.

    :param dataset_path: (str) directory that contains the sub-folders
    :param batch_size: (int) batch size used for training; must be at least
        ``select_num``
    :param select_num: (int) number of images to pick from each folder
    :return: (list) one inner list per accepted folder, each holding
        ``select_num`` paths of the form ``<folder path>/<file name>``
    :raises ValueError: if ``select_num`` is less than 1 or ``batch_size`` is
        smaller than ``select_num``
    """
    if select_num <= 0:
        raise ValueError(
            'select_num must be a positive integer, got {!r}'.format(
                select_num))
    if batch_size < select_num:
        # Not even one folder's samples would fit into a batch.
        raise ValueError(
            'batch_size ({!r}) must be at least select_num ({!r})'.format(
                batch_size, select_num))

    train_path_list = []
    for folder_name in os.listdir(dataset_path):
        folder_path = os.path.join(dataset_path, folder_name)
        if not os.path.isdir(folder_path):
            # Stray files next to the image folders cannot be listed.
            continue

        folder_images = os.listdir(folder_path)
        if len(folder_images) < select_num:
            print('Folder {} failed to fetch'.format(folder_name))
            continue

        # A folder with exactly select_num images is accepted:
        # random.sample can return the whole population.
        chosen = random.sample(folder_images, select_num)
        # '/' is used on purpose so the returned strings keep the same
        # format on every platform.
        train_path_list.append([folder_path + '/' + name for name in chosen])
    return train_path_list
if __name__ == '__main__':
    # Demo run: sample 8 images per folder for a training batch size of 32
    # and show the paths selected from the first accepted folder.
    dataset_root = r'E:/FaceNet-pytorch/facenet-pytorch--main/datasets/'
    sampled_paths = dataset_split(dataset_root, 32, 8)
    first_folder_paths = sampled_paths[0]
    print(first_folder_paths)