MTCNN(Tensorflow)学习记录(RNet的训练)

上一篇为RNet网络生成TFRecord文件,现在开始对RNet进行训练。

1 RNet的训练

进入train_models文件夹打开train_RNet.py,代码如下:

#coding:utf-8
from train_models.mtcnn_model import R_Net
from train_models.train import train


def train_RNet(base_dir, prefix, end_epoch, display, lr):
    """
    Train RNet by delegating to the shared ``train`` routine with ``R_Net``
    as the network factory.

    :param base_dir: forwarded to ``train`` as ``base_dir``,
        e.g. '../../DATA/imglists/RNet'
    :param prefix: forwarded to ``train`` as ``prefix``,
        e.g. '../data/MTCNN_model/RNet_Landmark/RNet'
    :param end_epoch: forwarded to ``train`` as ``end_epoch``, e.g. 22
    :param display: forwarded to ``train`` as ``display``, e.g. 100
    :param lr: forwarded to ``train`` as ``base_lr``, e.g. 0.001
    :return: None
    """
    train(
        net_factory=R_Net,
        prefix=prefix,
        end_epoch=end_epoch,
        base_dir=base_dir,
        display=display,
        base_lr=lr,
    )

if __name__ == '__main__':
    # Script entry point: launch RNet training with fixed settings.
    # The model name is substituted into the checkpoint path prefix.
    model_name = 'MTCNN'
    train_RNet(
        base_dir='../../DATA/imglists_noLM/RNet',
        prefix='../data/%s_model/RNet_No_Landmark/RNet' % model_name,
        end_epoch=22,
        display=100,
        lr=0.001,
    )

由上可以看出,这里调用了R_Net和train这两个函数,下面分别对它们进行说明。train函数的代码如下:

def train(net_factory, prefix, end_epoch, base_dir,
          display=200, base_lr=0.01):
    """
    Train PNet/RNet/ONet.

    The network name is taken from the last '/'-separated component of
    ``prefix``. It selects the data-loading path (a single tfrecord for
    'PNet', four tfrecords otherwise), the input image size and the loss
    weights. A checkpoint is written under ``prefix`` each time two passes
    over the dataset have been consumed.

    :param net_factory: callable invoked as
        ``net_factory(input_image, label, bbox_target, landmark_target, training=True)``;
        its result is unpacked into
        (cls_loss, bbox_loss, landmark_loss, L2_loss, accuracy) ops
    :param prefix: model path prefix, e.g. '../data/MTCNN_model/RNet_Landmark/RNet'
    :param end_epoch: number of epochs used to compute the total step count
    :param base_dir: directory holding the label file and the tfrecord files,
        e.g. '../../DATA/imglists/RNet'
    :param display: print the losses every ``display`` steps
    :param base_lr: base learning rate passed to ``train_model``
    :return: None
    """
    net = prefix.split('/')[-1]  # e.g. 'RNet'
    # e.g. '../../DATA/imglists/RNet/train_RNet_landmark.txt'
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    print(label_file)
    # The number of training samples is the number of lines in the label
    # file; the context manager guarantees the handle is released.
    with open(label_file, 'r') as f:
        num = sum(1 for _ in f)
    print("Total size of the dataset is: ", num)
    print(prefix)

    if net == 'PNet':
        dataset_dir = os.path.join(base_dir, 'train_%s_landmark.tfrecord_shuffle' % net)
        print('dataset dir is:', dataset_dir)
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net)
    else:
        # Every other net reads its data from four tfrecord files.
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        # NOTE(review): the landmark tfrecord is read from a hard-coded
        # directory instead of base_dir — confirm this is intentional.
        landmark_dir = os.path.join('../../DATA/imglists/RNet', 'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]

        # pos : part : landmark : neg = 1 : 1 : 1 : 3
        pos_ratio = 1.0 / 6
        part_ratio = 1.0 / 6
        landmark_ratio = 1.0 / 6
        neg_ratio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_ratio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_ratio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_ratio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_ratio))
        assert landmark_batch_size != 0, "Batch Size Error "
        # e.g. [64, 64, 192, 64] when config.BATCH_SIZE is 384
        batch_sizes = [pos_batch_size, part_batch_size, neg_batch_size, landmark_batch_size]
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net)

    # Input size and loss weights (cls : bbox : landmark) per network.
    if net == 'PNet':
        image_size = 12
        ratio_cls_loss = 1.0
        ratio_bbox_loss = 0.5
        ratio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        ratio_cls_loss = 1.0
        ratio_bbox_loss = 0.5
        ratio_landmark_loss = 0.5
    else:
        ratio_cls_loss = 1.0
        ratio_bbox_loss = 0.5
        ratio_landmark_loss = 1
        image_size = 48

    # Placeholders for one training batch.
    input_image = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, image_size, image_size, 3], name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 4], name='bbox_target')
    landmark_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 10], name='landmark_target')
    # NOTE(review): input_image is rebound to the output of
    # image_color_distort, and that rebound tensor (not the placeholder) is
    # the key used in the feed dict below — confirm whether feeding it
    # bypasses the distortion.
    input_image = image_color_distort(input_image)
    # Build the network and obtain the loss and accuracy ops.
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
        input_image, label, bbox_target, landmark_target, training=True)
    total_loss_op = (ratio_cls_loss * cls_loss_op
                     + ratio_bbox_loss * bbox_loss_op
                     + ratio_landmark_loss * landmark_loss_op
                     + L2_loss_op)
    train_op, lr_op = train_model(base_lr,
                                  total_loss_op,
                                  num)
    # init
    init = tf.global_variables_initializer()
    sess = tf.Session()

    # save model
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)

    # visualize some variables
    tf.summary.scalar("cls_loss", cls_loss_op)
    tf.summary.scalar("bbox_loss", bbox_loss_op)
    tf.summary.scalar("landmark_loss", landmark_loss_op)
    tf.summary.scalar("cls_accuracy", accuracy_op)
    # cls loss, bbox loss, landmark loss and L2 loss added together
    tf.summary.scalar("total_loss", total_loss_op)
    summary_op = tf.summary.merge_all()
    logs_dir = "../logs/%s" %(net)
    # makedirs also creates a missing parent directory ("../logs").
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer, projector_config)
    # begin
    coord = tf.train.Coordinator()
    # begin enqueue thread
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0  # steps since the last checkpoint
    # total steps
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0  # number of checkpoints written so far
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            # random flip
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)

            # The same batch feeds both the training step and the
            # periodic loss report.
            feed = {input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array}
            _, _, summary = sess.run([train_op, lr_op, summary_op], feed_dict=feed)

            if (step+1) % display == 0:
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run(
                    [cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, lr_op, accuracy_op],
                    feed_dict=feed)

                total_loss = (ratio_cls_loss * cls_loss
                              + ratio_bbox_loss * bbox_loss
                              + ratio_landmark_loss * landmark_loss
                              + L2_loss)
                print("%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,Landmark loss :%4f,L2 loss: %4f, Total Loss: %4f ,lr:%f " % (
                    datetime.now(), step+1, MAX_STEP, acc, cls_loss, bbox_loss, landmark_loss, L2_loss, total_loss, lr))

            # save every two epochs
            if i * config.BATCH_SIZE > num*2:
                epoch = epoch + 1
                i = 0
                path_prefix = saver.save(sess, prefix, global_step=epoch*2)
                print('path prefix is :', path_prefix)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        # Release the queue threads and the session on every exit path,
        # not only after a clean run.
        coord.request_stop()
        writer.close()
        coord.join(threads)
        sess.close()

R_Net()函数(在这个脚本里以net_factory的形式存在)的介绍在此。RNet的训练和PNet的训练差不多,大家可以互相对照,大多数相同的部分我就省略了。

RNet的训练数据来源是:先经过图像金字塔获得不同size的图像,由PNet检测出候选框,然后再通过NMS(非极大值抑制)算法筛选出合格的人脸框,最后将得到的训练数据resize成24*24。RNet网络最后的输出是通过全连接层得到的,这一点和PNet有差异,但是损失的计算方法是相同的。

RNet的训练中,每循环完两个数据周期,就记录一次模型的参数。

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值