代码源自 GitHub:https://github.com/AITTSMD/MTCNN-Tensorflow
该阶段代码取自仓库主目录下的 MTCNN-Tensorflow-master/train_models/train_PNet.py
#coding:utf-8
#从.py中调用其中的函数
from train_models.mtcnn_model import P_Net
from train_models.train import train
def train_PNet(base_dir, prefix, end_epoch, display, lr):
    """Train P-Net by delegating to the shared ``train`` routine.

    ``P_Net`` is passed to ``train`` as the network factory; every other
    argument is forwarded unchanged.

    :param base_dir: data directory forwarded to ``train`` (the script
        entry point uses '../../DATA/imglists/PNet')
    :param prefix: model path prefix,
        e.g. '../data/MTCNN_model/PNet_landmark/PNet'
    :param end_epoch: maximum number of epochs to train for
    :param display: forwarded as ``train``'s ``display`` argument, e.g. 100
        (presumably the logging interval in steps -- confirm in ``train``)
    :param lr: learning rate, forwarded as ``train``'s ``base_lr``
    """
    # P_Net is handed over as-is; `train` decides how to use the factory.
    train(P_Net, prefix, end_epoch, base_dir, display=display, base_lr=lr)
if __name__ == '__main__':
    # Data location passed to train_PNet as base_dir.
    base_dir = '../../DATA/imglists/PNet'

    # Model path prefix; with model_name = 'MTCNN' this resolves to
    # '../../data/MTCNN_model/PNet_landmark/PNet'.
    model_name = 'MTCNN'
    prefix = model_path = '../../data/%s_model/PNet_landmark/PNet' % model_name

    # Training hyper-parameters: last epoch, display value, learning rate.
    end_epoch, display, lr = 30, 100, 0.001

    train_PNet(
        base_dir=base_dir,
        prefix=prefix,
        end_epoch=end_epoch,
        display=display,
        lr=lr,
    )
以上是训练 PNet 部分的代码,看起来十分简短,但很容易看出它从其他 .py 文件中导入了两个函数:
1. 从 train_models 文件夹下的 mtcnn_model.py 中导入了 P_Net 函数
2. 从 train_models 文件夹下的 train.py 中导入了 train 函数
接下来,我们将这两个函数逐一展开,首先是 train 函数,其代码如下:
#coding:utf-8
def train(net_factory, prefix, end_epoch, base_dir,
display=200, base_lr=0.01): #定义一个训练的工具
"""
train PNet/RNet/ONet
:param net_factory:PNet函数
:param prefix: model path = '../data/MTCNN_model/PNet_landmark/PNet'
:param end_epoch:30
:param base_dir:'../../DATA/imglists/PNet',tfrecord文件的路径
:param dataset:
:param display:200
:param base_lr:0.01
:return:
"""
net = prefix.split('/')[-1] #net=PNet
label_file = os.path.join(base_dir,'train_%s_landmark.txt' % net)
print(label_file) #'../../DATA/imglists/PNet/train_PNet_landmark.txt'
f = open(label_file, 'r') #以读取的方式打开train_PNet_landmark.txt
num = len(f.readlines()) #计算训练样本数量
print("Total size of the dataset is: ", num) #打印数据集的总size
print(prefix) #打印路径'../data/MTCNN_model/PNet_landmark/PNet'
#PNet网络使用这个方式来获得数据
if net == 'PNet':
#dataset_dir = '../../DATA/imglists/PNet/train_PNet_landmark.tfrecord_shuffle'
dataset_dir = os.path.join(base_dir,'train_%s_landmark.tfrecord_shuffle' % net)
print('dataset dir is:',dataset_dir) #打印数据集.tfrecord路径
image_batch, label_batch, bbox_batch,landmark_batch = read_single_tfrecord(dataset_dir, config.BATCH_SIZE, net) #读取.tfrecord文件
#RNet网络使用3+1个tfrecords来获取数据
else:
pos_dir = os.path.join(base_dir,'pos_landmark.tfrecord_shuffle')
part_dir = os.path.join(base_dir,'part_landmark.tfrecord_shuffle')
neg_dir = os.path.join(base_dir,'neg_landmark.tfrecord_shuffle')
landmark_dir = os.path.join('../../DATA/imglists/RNet','landmark_landmark.tfrecord_shuffle')#3+1个.tfrecord路径
dataset_dirs = [pos_dir,part_dir,neg_dir,landmark_dir]
#各部分样本数据权重
pos_radio = 1.0/6;part_radio = 1.0/6;landmark_radio=1.0