facenet专题2:train_sotfmax.py源代码解析--数据加载

上一篇博文我们介绍了windows下如何运行train_sotfmax.py训练指定的数据库,本文将对train_sotfmax.py代码进行解析。

从主函数main(parse_arguments(sys.argv[1:]))开始,首先看函数parse_arguments,其定义就在train_sotfmax.py文件:

def parse_arguments(argv):
    parser = argparse.ArgumentParser()

    parser.add_argument('--logs_base_dir', type=str,
                        help='Directory where to write event logs.', default='~/logs/facenet')
    parser.add_argument('--models_base_dir', type=str,
                        help='Directory where to write trained models and checkpoints.', default='~/models/facenet')
    parser.add_argument('--gpu_memory_fraction', type=float,
                        help='Upper bound on the amount of GPU memory that will be used by the process.', default=1.0)
    parser.add_argument('--pretrained_model', type=str,
                        help='Load a pretrained model before training starts.')
    parser.add_argument('--data_dir', type=str,
                        help='Path to the data directory containing aligned face patches. Multiple directories are separated with colon.',
..................
return parser.parse_args(argv)

这里用到了parser = argparse.ArgumentParser()实例,add_argument()负责添加参数,指定类型和默认值,最后调用parser.parse_args(argv)对输入的参数修改其值不再是默认值,这里比较容易理解,往下看进入main函数,看第一行代码:

network = importlib.import_module(args.model_def)

若执行上一篇文章的脚本,这里实际执行的是 network = importlib.import_module(models.inception_resnet_v1),即导入models/文件夹下的inception_resnet_v1.py并命名为network,即所使用的inception_resnet_v1网络结构,具体结构后续再做详述。

接下来得到全路径的日志文件夹和模型文件夹:

    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    #子文件夹subdir名字为当前时间(年月日时分秒)
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    #凭借得到日志文件夹的全路径log_dir。
    if not os.path.isdir(log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    #不存在则创建一个。
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(model_dir):  # Create the model directory if it doesn't exist
        os.makedirs(model_dir)
    #同理得到模型文件夹

接着保存git的版本信息到log文件下的文件中(./logs/facenet/revision_info.txt)

    src_path, _ = os.path.split(os.path.realpath(__file__))
    #执行函数os.path.realpath(__file__),结果是__file__的值为执行当前程序的文件名即train_softmax.py,
    #返回值为当前的绝对路径,例如train_softmax.py在F盘facenet文件夹下则返回"F:\\facenet\\train_softmax.py"
    #os.path.split切分为F:\\facenet和train_softmax.py,即src_path为文件所在目录。
    facenet.store_revision_info(src_path, log_dir, ' '.join(sys.argv))
    #' '.join(sys.argv)将sys.argv的所有输入参数按空格隔开拼接成字符串。
def store_revision_info(src_path, output_dir, arg_string):
  
    # Get git hash
    gitproc = Popen(['git', 'rev-parse', 'HEAD'], stdout = PIPE, cwd=src_path)
    #创建子进程执行指令git rev-parse head 查看git仓库提交的SHA1值,也就是提交记录。
    #并将输出指向管道PIPE.
    (stdout, _) = gitproc.communicate()
    #执行该子进程得到标准输出,即git的提交记录号。
    git_hash = stdout.strip()
    #得到提交的记录
  
    # Get local changes
    #同上,比较前后版本号存储到git_diff.
    gitproc = Popen(['git', 'diff', 'HEAD'], stdout = PIPE, cwd=src_path)
    (stdout, _) = gitproc.communicate()
    git_diff = stdout.strip()
    
    # Store a text file in the log directory
    rev_info_filename = os.path.join(output_dir, 'revision_info.txt')
    with open(rev_info_filename, "w") as text_file:
        text_file.write('arguments: %s\n--------------------\n' % arg_string)
        text_file.write('git hash: %s\n--------------------\n' % git_hash)
        text_file.write('%s' % git_diff)

接下来设定随机种子后,调用train_set = facenet.get_dataset(args.data_dir)(train_set = facenet.get_dataset( ./data/CASIA_Web_Face_mtcnnpy_182))获取训练集和,我们来进入get_dataset来看一下是如何得到训练集和的。

def get_dataset(paths):
    dataset = []
    path_exp = paths
    classes = os.listdir(path_exp)
    classes.sort()
    nrof_classes = len(classes)
    for i in range(nrof_classes):
        class_name = classes[i]
        facedir = os.path.join(path_exp, class_name)
        if os.path.isdir(facedir):
            images = os.listdir(facedir)
            image_paths = [os.path.join(facedir,img) for img in images]
            dataset.append(ImageClass(class_name, image_paths))
    return dataset

这段代码比较容易理解,上一篇已经说过,CASIA_Web_Face_mtcnnpy_182文件夹下每一个文件夹对应一个人。这里会依次统计每个人对应的文件夹名字(class_name)和其下面所有带绝对路径的人脸图片文件名(image_paths),存储到ImageClass,ImageCalss其实就是一个两个成员的简单class,names对应class_name,即对应一个人的文件夹名字,image_paths为对应文件夹下所有带绝对路径的人脸图片文件名组成的list,故返回值train_set是一个ImageClass组成的list,list的每个元素是一个ImageClass实例,其包含了训练集某个人的人脸文件信息。接下来执行

    if args.filter_filename:
        train_set = filter_dataset(train_set, os.path.expanduser(args.filter_filename),
                                   args.filter_percentile, args.filter_min_nrof_images_per_class)

如果设定过滤文件的话,会对我们之前得到的训练文件库进行过滤,通常是不需要的。接下来

    pretrained_model = None
    if args.pretrained_model:
        pretrained_model = os.path.expanduser(args.pretrained_model)
        print('Pre-trained model: %s' % pretrained_model)

如果存在预训练的模型获取起绝对路径下的文件名,之前训练中也没有指定。最后执行

    if args.lfw_dir:
        print('LFW directory: %s' % args.lfw_dir)
        # Read the file containing the pairs used for testing
        pairs = lfw.read_pairs(os.path.expanduser(args.lfw_pairs))
        # Get the paths for the corresponding images
        lfw_paths, actual_issame = lfw.get_paths(os.path.expanduser(args.lfw_dir), pairs, args.lfw_file_ext)

这里调用lfw.read_pairs(args.lfw_pairs)用于读取lfw数据库中成对的图像来作为比较数据库。其默认目录是data\pairs.txt。其内容为:

每一行对应某个人的两张图像,例如 Akhmed_Zakayev    2    3表示 Akhmed_Zakayev文件下的 Akhmed_Zakayev_0002.jpg和 Akhmed_Zakayev——0003.jpg图像。返回的pais结构如下:

array([list(['Abel_Pacheco', '1', '4']),
       list(['Akhmed_Zakayev', '1', '3']),
       list(['Akhmed_Zakayev', '2', '3']), ...,
       list(['Shane_Loux', '1', 'Val_Ackerman', '1']),
       list(['Shawn_Marion', '1', 'Shirley_Jones', '1']),
       list(['Slobodan_Milosevic', '2', 'Sok_An', '1'])], dtype=object)

最后调用 lfw_paths, actual_issame = lfw.get_paths(os.path.expanduser(args.lfw_dir), pairs, args.lfw_file_ext),返回lfw数据库中同一人头像对的绝对路径。lfw_pahts为list每一个元素对应两个图像绝对路径组成的组元。issame是同样大小的list,存储bool类型表示对应位置的两个图像是否为同一人。接下来执行代码:

image_list, label_list = facenet.get_image_paths_and_labels(train_set)

我们来看一下函数get_image_paths_and_labels的定义:

def get_image_paths_and_labels(dataset):
    image_paths_flat = []
    labels_flat = []
    for i in range(len(dataset)):
        #list相加会结果是两个list的内容按顺序拼接到一起。
        image_paths_flat += dataset[i].image_paths
        #[i]*5=[i,i,i,i,i],故每个人脸图像生成所有labels拼到一起。
        labels_flat += [i] * len(dataset[i].image_paths)
    return image_paths_flat, labels_flat

由上面分析已经知道这里的输入参数dataset为train_set,其为由ImageClass实例组成的list。这里的作用是将所有人脸图像的带绝对路径的文件名依次存入维度为1的list: image_paths_flat,每一个为其一个元素。并按顺序将每个人的人脸图像依次编号为0、1、2.........,同样依次存入维度为1的list: labels_flat中。故image_list为所有带绝对路径人脸图片文件构成的一维list,而label_list则为对应每一个人脸图片的编号(不同的人编号不同、相同的人编号相同),即image_list按顺序依次存储了所有带路径的头像文件名,label_list依次按顺序存储了图像对应的label。接下来:

labels = ops.convert_to_tensor(label_list, dtype=tf.int32)

返回函数ops.convert_to_tensor将各种类型的Python对象转换为张量对象,详细介绍可参考:https://blog.csdn.net/fangfanglovezhou/article/details/105626072 ,这里label_list被转为张量labels。随后调用

range_size = array_ops.shape(labels)[0]

返回头像文件的个数range_size,即range_size为所有训练图片样本的个数 ,接下来执行

index_queue = tf.train.range_input_producer(range_size, num_epochs=None,
                                                    shuffle=True, seed=None, capacity=32)

tf.train.range_input_producer用于多线程读取队列,详细介绍可参考range_input_producer,调用该函数会产生一个包含range_size个元素的队列(取值:0到range_size-1),接下来:

index_dequeue_op = index_queue.dequeue_many(args.batch_size * args.epoch_size, 'index_dequeue')

表示从队列中取出args.batch_size (默认:90)* args.epoch_size(默认1000)个元素,即每次出队列操作取出90000个元素,接下来


#学习率        
learning_rate_placeholder = tf.placeholder(tf.float32, name='learning_rate')
#batch_size
batch_size_placeholder = tf.placeholder(tf.int32, name='batch_size')
#phase_train
phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')
#训练图像路径
image_paths_placeholder = tf.placeholder(tf.string, shape=(None, 1), name='image_paths')
#图像标签
labels_placeholder = tf.placeholder(tf.int64, shape=(None, 1), name='labels')

输入5个训练的参数,然后:

       #设置队列的参数配置容量、类型以及元素尺寸。
        input_queue = data_flow_ops.FIFOQueue(capacity=100000,
                                              dtypes=[tf.string, tf.int64],
                                              shapes=[(1,), (1,)],
                                              shared_name=None, name=None)
        #根据输入队列的配置来建立输入,按上面配置依次从输入取出数据。
        #(1,)表示只有一个元素的组元
        enqueue_op = input_queue.enqueue_many([image_paths_placeholder, labels_placeholder], name='enqueue_op')
        nrof_preprocess_threads = 4
        images_and_labels = []
        for _ in range(nrof_preprocess_threads):
            #从队列input_queue取出元素。
            filenames, label = input_queue.dequeue()
            images = []
            #将多个文件名filenames分解成单个文件名
            for filename in tf.unstack(filenames):
                #读取头像文件的内容
                file_contents = tf.read_file(filename)
                #解码得到图像矩阵
                image = tf.image.decode_png(file_contents)
                if args.random_rotate:
                    image = tf.py_func(facenet.random_rotate_image, [image], tf.uint8)
                if args.random_crop:
                    image = tf.random_crop(image, [args.image_size, args.image_size, 3])
                else:
                    image = tf.image.resize_image_with_crop_or_pad(image, args.image_size, args.image_size)
                if args.random_flip:
                    image = tf.image.random_flip_left_right(image)

                # pylint: disable=no-member
                #设定图像形状
                image.set_shape((args.image_size, args.image_size, 3))
                #将处理好的图像矩阵存储到images。
                images.append(tf.image.per_image_standardization(image))
            #图像以及label存储在images_and_labels中。
            images_and_labels.append([images, label])
        #多线程按照batch大小从images_and_labels中取出要训练的图像和对应的labels.
        image_batch, label_batch = tf.train.batch_join(
            images_and_labels, batch_size=batch_size_placeholder,
            shapes=[(args.image_size, args.image_size, 3), ()], enqueue_many=True,
            capacity=4 * nrof_preprocess_threads * args.batch_size,
            allow_smaller_final_batch=True)
        image_batch = tf.identity(image_batch, 'image_batch')
        image_batch = tf.identity(image_batch, 'input')
        label_batch = tf.identity(label_batch, 'label_batch')

        print('Total number of classes: %d' % nrof_classes)
        print('Total number of examples: %d' % len(image_list))

        print('Building training graph')

这里代码比较多,功能是建立一个多线程的图模型,实现读取训练数据文件得到训练数据,并对数据进行处理后,按batch取出得到准备好的训练数据。

到此训练集和验证机的数据加载完成,下一篇我们将介绍下网络结构。

 

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值