MTCNN (TensorFlow) Study Notes: Training PNet

The previous post generated the TFRecord file for the PNet network; now we can start training PNet.

1 Training PNet

Go into the train_models folder and open train_PNet.py. The code is as follows:

#coding:utf-8
from train_models.mtcnn_model import P_Net
from train_models.train import train


def train_PNet(base_dir, prefix, end_epoch, display, lr):
    """
    train PNet
    :param dataset_dir: tfrecord path
    :param prefix:
    :param end_epoch: max epoch for training
    :param display:
    :param lr: learning rate
    :return:
    """
    #base_dir: tfrecord文件的路径
    #prefix:'../data/MTCNN_model/PNet_landmark/PNet'
    #end_epoch:训练的最大周期
    #display:100
    #lr:学习率
    net_factory = P_Net
    train(net_factory,prefix, end_epoch, base_dir, display=display, base_lr=lr)

if __name__ == '__main__':
    #data path
    base_dir = '../../DATA/imglists/PNet'
    model_name = 'MTCNN'
    #model_path = '../data/%s_model/PNet/PNet' % model_name
    #with landmark
    model_path = '../data/%s_model/PNet_landmark/PNet' % model_name
            
    prefix = model_path
    end_epoch = 30
    display = 100
    lr = 0.001
    train_PNet(base_dir, prefix, end_epoch, display, lr)
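
Both scripts rely on shared hyper-parameters from a config module (referenced as config inside train.py below). That file is not shown in this post; based on the [384, 12, 12, 3] placeholder shape noted in the comments further down, a plain-Python stand-in would look roughly like this sketch (the LR_EPOCH decay epochs are an assumption):

class config:
    BATCH_SIZE = 384          # matches the [384, 12, 12, 3] input shape noted below
    LR_EPOCH = [6, 14, 20]    # assumed epochs at which the learning rate decays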

As you can see above, two functions are called: P_Net and train. Let's pull them out and walk through them; the code of the train function is as follows:

def train(net_factory, prefix, end_epoch, base_dir,
          display=200, base_lr=0.01):
    """
    train PNet/RNet/ONet
    :param net_factory:
    :param prefix: model path
    :param end_epoch:
    :param dataset:
    :param display:
    :param base_lr:
    :return:
    """
    #net_factory:P_Net函数
    #prefix:'../data/MTCNN_model/PNet_landmark/PNet'
    #end_epoch:30
    #base_dir:'../../DATA/imglists/PNet',tfrecord文件的路径
    #displah:传进来的值是100
    #lr:传进来的值是0.001
    
    net = prefix.split('/')[-1]                   #net=PNet
    #label file
    label_file = os.path.join(base_dir,'train_%s_landmark.txt' % net) 
    #'../../DATA/imglists/PNet/train_PNet_landmark.txt'

    print(label_file)                             # prints '../../DATA/imglists/PNet/train_PNet_landmark.txt'
    f = open(label_file, 'r')                     # open train_PNet_landmark.txt
    # get number of training examples
    num = len(f.readlines())                      # total number of training samples, num=1429422
    print("Total size of the dataset is: ", num)
    print(prefix)                                 # prints '../data/MTCNN_model/PNet_landmark/PNet'

    #PNet use this method to get data
    if net == 'PNet':
        dataset_dir = os.path.join(base_dir,'train_%s_landmark.tfrecord_shuffle' % net)
        #dataset_dir='../../DATA/imglists/PNet/train_PNet_landmark.tfrecord_shuffle'
        print('dataset dir is:',dataset_dir)
        # print the path
        image_batch, label_batch, bbox_batch,landmark_batch = read_single_tfrecord(dataset_dir, config.BATCH_SIZE, net)
        # read the tfrecord file
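        # Shapes (with config.BATCH_SIZE = 384): image_batch [384, 12, 12, 3],
        # label_batch [384], bbox_batch [384, 4], landmark_batch [384, 10],
        # matching the placeholders defined below.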
    
    #RNet and ONet use 4 tfrecords (pos, part, neg, landmark) to get data
    else:
        pos_dir = os.path.join(base_dir,'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir,'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir,'neg_landmark.tfrecord_shuffle')
        #landmark_dir = os.path.join(base_dir,'landmark_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join('../../DATA/imglists/RNet','landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir,part_dir,neg_dir,landmark_dir]
        pos_radio = 1.0/6;part_radio = 1.0/6;landmark_radio=1.0/6;neg_radio=3.0/6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE*pos_radio))
        assert pos_batch_size != 0,"Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE*part_radio))
        assert part_batch_size != 0,"Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE*neg_radio))
        assert neg_batch_size != 0,"Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE*landmark_radio))
        assert landmark_batch_size != 0,"Batch Size Error "
        batch_sizes = [pos_batch_size,part_batch_size,neg_batch_size,landmark_batch_size]
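        # Worked example with config.BATCH_SIZE = 384:
        #   pos  = ceil(384 * 1/6) = 64,  part = ceil(384 * 1/6) = 64,
        #   neg  = ceil(384 * 3/6) = 192, landmark = ceil(384 * 1/6) = 64
        #   -> 64 + 64 + 192 + 64 = 384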
        #print('batch_size is:', batch_sizes)
        image_batch, label_batch, bbox_batch,landmark_batch = read_multi_tfrecords(dataset_dirs,batch_sizes, net)        
        
    # loss weights and input size for each network
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 0.5;
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 0.5;
    else:
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 1;
        image_size = 48
    
    #define placeholder
    # input images, shape [384, 12, 12, 3]
    input_image = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, image_size, image_size, 3], name='input_image')
    # input labels, shape [384]
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    # input bboxes, shape [384, 4]
    bbox_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 4], name='bbox_target')
    # input landmarks, shape [384, 10]
    landmark_target = tf.placeholder(tf.float32,shape=[config.BATCH_SIZE,10],name='landmark_target')
    #get loss and accuracy
    # apply color distortion to the images
    input_image = image_color_distort(input_image)
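    # (image_color_distort is not shown in this post; presumably it applies
    # random brightness/contrast-style jitter built from tf.image ops)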
    # get the losses for face classification, bounding box regression, landmark regression and L2 regularization, plus the face classification accuracy
    cls_loss_op,bbox_loss_op,landmark_loss_op,L2_loss_op,accuracy_op = net_factory(input_image, label, bbox_target,landmark_target,training=True)
    #train, update learning rate (3 losses)
    # multiply the three losses by their respective weights, then add the regularization loss to get the total loss
    total_loss_op  = radio_cls_loss*cls_loss_op + radio_bbox_loss*bbox_loss_op + radio_landmark_loss*landmark_loss_op + L2_loss_op
    # get the model's learning-rate op and train_op
    train_op, lr_op = train_model(base_lr,
                                  total_loss_op,
                                  num)
    # initialize all variables
    init = tf.global_variables_initializer()
    sess = tf.Session()


    #save model
    saver = tf.train.Saver(max_to_keep=0)
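    # max_to_keep=0 tells tf.train.Saver to keep every checkpoint
    # rather than deleting older ones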
    sess.run(init)

    #visualize some variables
    # use TensorBoard to visualize how these values change
    tf.summary.scalar("cls_loss",cls_loss_op)#cls_loss
    tf.summary.scalar("bbox_loss",bbox_loss_op)#bbox_loss
    tf.summary.scalar("landmark_loss",landmark_loss_op)#landmark_loss
    tf.summary.scalar("cls_accuracy",accuracy_op)#cls_acc
    tf.summary.scalar("total_loss",total_loss_op)#cls_loss, bbox loss, landmark loss and L2 loss add together
    # merge all summaries
    summary_op = tf.summary.merge_all()
    # create the log directory
    logs_dir = "../logs/%s" %(net)
    if os.path.exists(logs_dir) == False:
        os.mkdir(logs_dir)
    # write the event files into the directory
    writer = tf.summary.FileWriter(logs_dir,sess.graph)
    # use the projector.ProjectorConfig() class to help generate the log files
    projector_config = projector.ProjectorConfig()
    # write the projector content into the log files
    projector.visualize_embeddings(writer,projector_config)
    #begin 
    # use tf.train.Coordinator() to create a thread coordinator object
    coord = tf.train.Coordinator()
    #begin enqueue thread
    # start multiple worker threads that push tensors (training data) into the filename queues
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    #total steps
    # total number of training steps
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
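    # Worked example: with num = 1429422 and config.BATCH_SIZE = 384,
    # int(num / config.BATCH_SIZE + 1) = 3723 steps per epoch,
    # so MAX_STEP = 3723 * 30 = 111690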
    epoch = 0
    # finalize the graph so no new ops can be added, preventing memory leaks
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            # use coord.should_stop() to check whether all threads should stop;
            # once every file in the queue has been dequeued, an OutOfRangeError
            # is raised, and all threads in the session should then stop
            if coord.should_stop():
                break
            # fetch one batch of data
            image_batch_array, label_batch_array, bbox_batch_array,landmark_batch_array = sess.run([image_batch, label_batch, bbox_batch,landmark_batch])
            #random flip
            # randomly flip the images
            image_batch_array,landmark_batch_array = random_flip_images(image_batch_array,label_batch_array,landmark_batch_array)
            '''
            print('im here')
            print(image_batch_array.shape)
            print(label_batch_array.shape)
            print(bbox_batch_array.shape)
            print(landmark_batch_array.shape)
            print(label_batch_array[0])
            print(bbox_batch_array[0])
            print(landmark_batch_array[0])
            '''


            _,_,summary = sess.run([train_op, lr_op ,summary_op], feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array,landmark_target:landmark_batch_array})
            # every display steps (100 here), print the time and the losses
            if (step+1) % display == 0:
                #acc = accuracy(cls_pred, labels_batch)
                cls_loss, bbox_loss,landmark_loss,L2_loss,lr,acc = sess.run([cls_loss_op, bbox_loss_op,landmark_loss_op,L2_loss_op,lr_op,accuracy_op],
                                                             feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, landmark_target: landmark_batch_array})

                total_loss = radio_cls_loss*cls_loss + radio_bbox_loss*bbox_loss + radio_landmark_loss*landmark_loss + L2_loss
                # landmark loss: %4f,
                print("%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,Landmark loss :%4f,L2 loss: %4f, Total Loss: %4f ,lr:%f " % (
                datetime.now(), step+1,MAX_STEP, acc, cls_loss, bbox_loss,landmark_loss, L2_loss,total_loss, lr))


            #save every two epochs
            # (i * config.BATCH_SIZE > num * 2 means two full passes over the data)
            if i * config.BATCH_SIZE > num*2:
                epoch = epoch + 1
                i = 0
                path_prefix = saver.save(sess, prefix, global_step=epoch*2)
                print('path prefix is :', path_prefix)
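                # saver.save appends '-<global_step>' to prefix, so the first
                # checkpoint is written as '../data/MTCNN_model/PNet_landmark/PNet-2'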
            writer.add_summary(summary,global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()

The training code above relies on read_single_tfrecord, P_Net (which appears in this script as net_factory), train_model(), and random_flip_images().
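
train_model itself is not shown in this post. As a reference, here is a minimal sketch of what such a function might look like: a piecewise-constant learning-rate schedule that decays at fixed epochs, driving a momentum optimizer. The 0.1 decay factor and the config.LR_EPOCH decay epochs are assumptions, not confirmed values:

import tensorflow as tf

def train_model(base_lr, loss, data_num):
    """Sketch: decay the learning rate 10x at each epoch listed in config.LR_EPOCH."""
    lr_factor = 0.1  # assumed decay factor
    global_step = tf.Variable(0, trainable=False)
    # convert the decay epochs into global-step boundaries
    # (config as sketched earlier: BATCH_SIZE and the assumed LR_EPOCH list)
    boundaries = [int(epoch * data_num / config.BATCH_SIZE) for epoch in config.LR_EPOCH]
    lr_values = [base_lr * (lr_factor ** x) for x in range(len(config.LR_EPOCH) + 1)]
    lr_op = tf.train.piecewise_constant(global_step, boundaries, lr_values)
    optimizer = tf.train.MomentumOptimizer(lr_op, 0.9)
    train_op = optimizer.minimize(loss, global_step)
    return train_op, lr_op

Likewise, random_flip_images presumably mirrors a random subset of images horizontally and remaps the landmark coordinates to match (x -> 1 - x, then swapping the left/right eye and mouth-corner points). A sketch, assuming the label convention used in this series (pos = 1, landmark samples = -2):

import random
import cv2
import numpy as np

def random_flip_images(image_batch, label_batch, landmark_batch):
    """Sketch: horizontally flip pos and landmark samples with 50% probability."""
    if random.choice([0, 1]) > 0:
        flip_landmark_idx = np.where(label_batch == -2)[0]
        flip_pos_idx = np.where(label_batch == 1)[0]
        for i in np.concatenate((flip_landmark_idx, flip_pos_idx)):
            cv2.flip(image_batch[i], 1, image_batch[i])  # in-place horizontal flip
        for i in flip_landmark_idx:
            lm = landmark_batch[i].reshape(-1, 2)
            lm = np.asarray([(1 - x, y) for (x, y) in lm])  # mirror normalized x
            lm[[0, 1]] = lm[[1, 0]]  # swap left/right eye
            lm[[3, 4]] = lm[[4, 3]]  # swap left/right mouth corner
            landmark_batch[i] = lm.ravel()
    return image_batch, landmark_batch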
