基于TensorFlow的MTCNN人脸检测算法( 训练PNet 代码注解 )

本文介绍了如何使用TensorFlow训练MTCNN中的PNet部分,详细解析了训练过程涉及的train_PNet.py文件,包括从mtcnn_model.py导入的P_Net函数模块和train.py中的train函数模块。建议读者具备TensorFlow框架及神经网络基础知识。
摘要由CSDN通过智能技术生成

代码源自Github:https://github.com/AITTSMD/MTCNN-Tensorflow
该阶段代码取自 Git主 代码目录下 MTCNN-Tensorflow-master / train_models / train_PNet.py

#coding:utf-8
#从.py中调用其中的函数
from train_models.mtcnn_model import P_Net
from train_models.train import train

def train_PNet(base_dir, prefix, end_epoch, display, lr):
    """
    定义train PNet
    :param dataset_dir: .tfrecord的路径
    :param prefix:'../data/MTCNN_model/PNet_landmark/PNet'
    :param end_epoch: 用于train的最大的epoch(周期)
    :param display:100
    :param lr: learning rate 学习率
    """
    net_factory = P_Net
    train(net_factory,prefix, end_epoch, base_dir, display=display, base_lr=lr)

if __name__ == '__main__':
    base_dir = '../../DATA/imglists/PNet'									#数据路径 
    model_name = 'MTCNN'
    model_path = '../../data/%s_model/PNet_landmark/PNet' % model_name		#model_path = '../data/MTCNN_model/PNet_landmark/PNet'
            
    prefix = model_path
    end_epoch = 30
    display = 100
    lr = 0.001
    train_PNet(base_dir, prefix, end_epoch, display, lr)

以上是训练PNet部分的代码,看起来十分简短,但是很容易可以看出从.py中调用了两个函数
1.从 train_models 文件夹下的 mtcnn_model.py中调用了 P_Net 函数模块
2.从 train_models 文件夹下的 train.py中调用了 train 函数模块
接下来,我们将这两个函数模块导出来,首先是 train 函数模块的代码如下:
#coding:utf-8


def train(net_factory, prefix, end_epoch, base_dir,
          display=200, base_lr=0.01):												#定义一个训练的工具
    """
    train PNet/RNet/ONet
    :param net_factory:PNet函数
    :param prefix: model path = '../data/MTCNN_model/PNet_landmark/PNet'
    :param end_epoch:30
    :param base_dir:'../../DATA/imglists/PNet',tfrecord文件的路径
    :param dataset:
    :param display:200
    :param base_lr:0.01
    :return:
    """
    
    net = prefix.split('/')[-1]														#net=PNet
    
    label_file = os.path.join(base_dir,'train_%s_landmark.txt' % net)
    
    print(label_file)											#'../../DATA/imglists/PNet/train_PNet_landmark.txt'
    f = open(label_file, 'r')									#以读取的方式打开train_PNet_landmark.txt
    
    num = len(f.readlines())									#计算训练样本数量							
    print("Total size of the dataset is: ", num)				#打印数据集的总size
    print(prefix)												#打印路径'../data/MTCNN_model/PNet_landmark/PNet'

    #PNet网络使用这个方式来获得数据
    if net == 'PNet':
        #dataset_dir = '../../DATA/imglists/PNet/train_PNet_landmark.tfrecord_shuffle'
        dataset_dir = os.path.join(base_dir,'train_%s_landmark.tfrecord_shuffle' % net)
        print('dataset dir is:',dataset_dir)					#打印数据集.tfrecord路径
        image_batch, label_batch, bbox_batch,landmark_batch = read_single_tfrecord(dataset_dir, config.BATCH_SIZE, net)	#读取.tfrecord文件
        
    #RNet网络使用3+1个tfrecords来获取数据   
    else:
        pos_dir = os.path.join(base_dir,'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir,'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir,'neg_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join('../../DATA/imglists/RNet','landmark_landmark.tfrecord_shuffle')#3+1个.tfrecord路径
        
        dataset_dirs = [pos_dir,part_dir,neg_dir,landmark_dir]
        
        #各部分样本数据权重
        pos_radio = 1.0/6;part_radio = 1.0/6;landmark_radio=1.0
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值