关闭

将数据划分为训练数据及测试数据(div_train_val.py 解析)

720人阅读 评论(0) 收藏 举报
分类:

将LFW数据划分为face,non-face两个图像数据文件,在此基础上,提取训练数据及测试数据。

训练数据,在face文件中提取一部分,在non-face文件中提取一部分。

测试数据,在face文件中提取一部分,在non-face文件中提取一部分。

使用 div_train_val.py ,能得到训练数据及测试数据,其文件内容为,图像的路径及标注。

# -*- coding: utf-8 -*-
"""
Created on Mon Jun  8 14:15:21 2015
@brief: 用与划分训练数据,train.list 和 val.list
@author: Riwei Chen <Riwei.Chen@outlook.com>
"""
import  os
def div_database(filepath,savepath,top_num=1000,equal_num=False,full_path=False):
    '''
    @brief: 提取webface人脸数据
    @param : filepath 文件路径
    @param : top_num=1000,表示提取的类别数目,face,non-face -> top_num= 2
    @param : equal_num 是否强制每个人都相同
    '''
    dirlists=os.listdir(filepath)  #crop_images(存放图像文件)文件下的目录
    dict_id_num={}  #定义一个存放子目录长度的元组
    for subdir in dirlists:
        dict_id_num[subdir]=len(os.listdir(os.path.join(filepath,subdir)))  #存储每个子目录下文件的长度(例如face 子目录里面所包含的图像数量)
    #sorted(dict_id_num.items, key=lambda dict_id_num:dict_id_num[1])
    sorted_num_id=sorted([(v, k) for k, v in dict_id_num.items()], reverse=True) #排序,["face",length] -> [length,"face"]
    select_ids=sorted_num_id[0:top_num]
    if equal_num == True:
        trainfile=save_path+'train_'+str(top_num)+'_equal.list'
        testfile=save_path+'val_'+str(top_num)+'_qeual.list'
    else:  #新建训练文件及测试文件
        trainfile=save_path+'train_'+str(top_num)+'.list'
        testfile=save_path+'val_'+str(top_num)+'.list'
    fid_train=open(trainfile,'w') 
    fid_test=open(testfile,'w')
    pid=0
    pre = ""
    if full_path ==True:
        pre = data_path
    #将数据划分为训练数据及测试数据,face 中选取一部分划分为训练数据,另一部分划分为测试数据;non-face中选取一部分划分为训练数据,另一部分划分为测试数据
    for  select_id in select_ids:
        subdir=select_id[1]
        filenamelist=os.listdir(os.path.join(filepath,subdir))  #获取图像文件名
        num=1
        for filename in filenamelist :
            #print select_ids[top_num-1]
            if equal_num==True and num>select_ids[top_num-1][0]:
                break
            if num%10!=0:
                fid_train.write(os.path.join(pre,subdir,filename)+'\t'+str(pid)+'\n')  #保存图像路径及其标注
            else:
                fid_test.write(os.path.join(pre,subdir,filename)+'\t'+str(pid)+'\n')
            num=num+1
        pid=pid+1
    fid_train.close()
    fid_test.close()

if __name__=='__main__':
    data_path = '/home/zhuangni/code/FaceDetection/ReprocessData/alfw/crop_images'
    save_path = '/home/zhuangni/code/FaceDetection/Data/aflw/'
    div_database(data_path,save_path, top_num=2, equal_num=False,full_path =True)



0
0

查看评论
* 以上用户言论只代表其个人观点,不代表CSDN网站的观点或立场
    个人资料
    • 访问:62242次
    • 积分:1053
    • 等级:
    • 排名:千里之外
    • 原创:38篇
    • 转载:16篇
    • 译文:0篇
    • 评论:92条
    最新评论