3、tensorflow2.0 实现MTCNN、训练O_net网络,并进行测试图片

训练O_net网络,并测试图片

  1. 上一篇,我们已经知道如何生成O_net训练集,这次我们开始训练Onet网络。
  2. 训练完成后,保存权重,我们随机抽取一张照片,测试一下效果。

代码:
train_Onet.py


import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import metrics
from red_tf import *
from MTCNN_ import Onet,cls_ohem,cal_accuracy,bbox_ohem
from tqdm import tqdm
import cv2



data_path = "48/train_ONet_landmark.tfrecord_shuffle"


# 加载pokemon数据集的工具!
def load_pokemon(mode='train'):
    """ 加载pokemon数据集的工具!
    :param root:    数据集存储的目录
    :param mode:    mode:当前加载的数据是train,val,还是test
    :return:
    """
    # # 创建数字编码表,范围0-4;
    # name2label = {}  # "sq...":0   类别名:类标签;  字典 可以看一下目录,一共有5个文件夹,5个类别:0-4范围;
    # for name in sorted(os.listdir(os.path.join(root))):     # 列出所有目录;
    #     if not os.path.isdir(os.path.join(root, name)):
    #         continue
    #     # 给每个类别编码一个数字
    #     name2label[name] = len(name2label.keys())

    # 读取Label信息;保存索引文件images.csv
    # [file1,file2,], 对应的标签[3,1] 2个一一对应的list对象。
    # 根据目录,把每个照片的路径提取出来,以及每个照片路径所对应的类别都存储起来,存储到CSV文件中。
    size = 48
    images,labels,boxes = red_tf(data_path,size)

    # 图片切割成,训练70%,验证15%,测试15%。
    if mode == 'train':                                                     # 100% 训练集
        images = images[:int(len(images))]
        labels = labels[:int(len(labels))]
        boxes  = boxes[:int(len(boxes))]
    elif mode == 'val':                                                     # 15% = 70%->85%  验证集
        images = images[int(0.7 * len(images)):int(0.85 * len(images))]
        labels = labels[int(0.7 * len(labels)):int(0.85 * len(labels))]
        boxes = boxes[int(0.7 * len(boxes)):int(0.85 * len(boxes))]
    else:                                                                   # 15% = 70%->85%  测试集
        images = images[int(0.85 * len(images)):]
        labels = labels[int(0.85 * len(labels)):]
        boxes = boxes[int(0.85 * len(boxes)):]
    ima = tf.data.Dataset.from_tensor_slices(images)
    lab = tf.data.Dataset.from_tensor_slices(labels)
    roi = tf.data.Dataset.from_tensor_slices(boxes)

    train_data = tf.data.Dataset.zip((ima, lab, roi)).shuffle(1000).batch(6)
    train_data = list(train_data.as_numpy_iterator())
    return train_data

# 图像色相变换
def image_color_distort(inputs):
    inputs = tf.image.random_contrast(inputs, lower=0.5, upper=1.5)
    inputs = tf.image.random_brightness(inputs, max_delta=0.2)
    inputs = tf.image.random_hue(inputs,max_delta= 0.2)
    inputs = tf.image.random_saturation(inputs,lower = 0.5, upper= 1.5)
    return inputs



def train(eopch):
    model = Onet()
    model.load_weights("onet.h5")

    optimizer = keras.optimizers.Adam(learning_rate=1e-3)
    off = 1000
    acc_meter = metrics.Accuracy()
    for epoch in tqdm(range(eopch)):

        for i,(img,lab,boxes) in enumerate(load_pokemon("train")):


            img = image_color_distort(img)
            # 开一个gradient tape, 计算梯度
            with tf.GradientTape() as tape:
                cls_prob, bbox_pred,laim = model(img)
                cls_loss = cls_ohem(cls_prob, lab)
                bbox_loss = bbox_ohem(bbox_pred, boxes,lab)
                # landmark_loss = landmark_loss_fn(landmark_pred, landmark_batch, label_batch)
                # accuracy = cal_accuracy(cls_prob, label_batch)


                total_loss_value = cls_loss + 0.5 * bbox_loss
                grads = tape.gradient(total_loss_value, model.trainable_variables)
                optimizer.apply_gradients(zip(grads, model.trainable_variables))
            if i % 200 == 0:
                print('Training loss (for one batch) at step %s: %s' % (i, float(total_loss_value)))
                print('Seen so far: %s samples' % ((i + 1) * 6))
           

        for i, (v_img, v_lab1, boxes) in enumerate(load_pokemon("val")):
            v_img = image_color_distort(v_img)
            with tf.GradientTape() as tape:
                cls_prob, bbox_pred,laim= model(v_img)
                cls_loss = cls_ohem(cls_prob, v_lab1)
                bbox_loss = bbox_ohem(bbox_pred, boxes,v_lab1)
                # landmark_loss = landmark_loss_fn(landmark_pred, landmark_batch, label_batch)
                # accuracy = cal_accuracy(cls_prob, label_batch)


                total_loss_value = cls_loss + 0.5 * bbox_loss
                grads = tape.gradient(total_loss_value, model.trainable_variables)
                optimizer.apply_gradients(zip(grads, model.trainable_variables))
            if i % 200 == 0:
                print('val___ loss (for one batch) at step %s: %s' % (i, float(total_loss_value)))
                print('Seen so far: %s samples' % ((i + 1) * 6))
    model.save_weights('./Weights/Onet_wight/onet_30.ckpt')
train(30)

到这里我们已经训练完成了,并把网络权重参数也保存下来,接下来我们开始进行预测。
预测代码:


from Detection.Detect import detect_pnet,detect_Rnet,detect_Onet

import cv2
import numpy as np


def prediction(image_path):
	
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    P_boxes, P_boxes_c = detect_pent(image)

    R_boxes,R_boxes_c = detect_Rnet(image,P_boxes_c)

    O_boxes,O_boxes_c = detect_Onet(image,R_boxes_c)


    # if ret == False:
    #     # 未检测到人脸
    #     print("该图片未检测到人脸")
    for i in range(O_boxes_c.shape[0]):
        bbox = O_boxes_c[i, :4]
        score = O_boxes_c[i, 4]
        corpbbox = [int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])]
        # 画人脸框
        cv2.rectangle(image, (corpbbox[0], corpbbox[1]),
                      (corpbbox[2], corpbbox[3]), (255, 0, 0), 1)
        # 判别为人脸的置信度
        cv2.putText(image, '{:.2f}'.format(score),
                    (corpbbox[0], corpbbox[1] - 2),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)

    cv2.imshow('im', image)
    cv2.waitKey(0)

	cv2.destroyAllWindows()

image = cv2.imread("./24.jpg")	
prediction(image)

在这里插入图片描述
这是该MTCNN检测出来的目标。看来效果不错。

到这里就完。代码中比较重要的回归框我没有表述,这个我觉得我描述的比较菜,我这里推荐一位大神些的边框回归原理,有兴趣的可以去看一下。

边框回归(Bounding Box Regression)详解

  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值