Implementing MTCNN in TensorFlow 2.0: generating the R_net dataset and training R_net

1. Generate the R_net dataset and feed it into R_net for training

To improve the quality of the R_net data, the R_net dataset is first refined by running it through P_net:

Crop the region under each bounding box out of the original image and resize it to 24×24.

Run P_Net to obtain each bounding box's classification score and coordinate regression values.

The bounding box coordinates P_Net infers may fall outside the original image; in that case the coordinates are clipped to the image boundary and the missing area is zero-filled.

Run non-maximum suppression over the boxes, keyed on their scores, to obtain the boxes to be calibrated.

Calibrate those boxes with the network's predicted regression values (P_Net's offsets here; R_Net's offsets later when generating the O_net data) to get the final bounding box coordinates.

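Putting those steps together, the flow for refining the R_net data looks roughly like this (a sketch of the pipeline implemented by detect.py below):

# refine the R_net dataset with P_net -- see detect_pent() in detect.py
# 1. run P_Net over an image pyramid   -> candidate boxes + scores + offsets
# 2. NMS on the scored boxes           -> de-duplicated candidates
# 3. apply the offsets (calibrate)     -> boxes_c
# 4. square + pad each box, crop it from the original image (zero-fill outside)
# 5. resize each crop to 24x24         -> R_net training samples
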
On to the code.

1.1 The code below runs P_net inference. Its detect_pent() is used to refine the R_net dataset. detect.py

import tensorflow as tf
import numpy as np
import cv2
from MTCNN_ import Pnet,Rnet,Onet


min_face_size = 20



# preprocess input for inference
def processed_img(img, scale):
    '''Rescale the image and normalize pixel values to [-1, 1].
    '''
    h,w,_ = img.shape
    n_h = int(h*scale)
    n_w = int(w*scale)
    dsize = (n_w,n_h)
    img_resized = cv2.resize(np.array(img), dsize,interpolation=cv2.INTER_LINEAR)
    img_resized = (img_resized - 127.5)/128
    return img_resized

# generate candidate boxes from the P-Net output maps
def generate_bounding_box(cls_pro,bbox_pred,scale,threshold):

    stride = 2
    cellsize = 12
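    # each cell of the P-Net output map corresponds to a 12x12 window in the
    # scaled input; the stride of 2 comes from P-Net's single 2x2 max-pool layer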
    # np.where returns a tuple (row idxs, col idxs) of the cells whose face
    # probability exceeds the threshold
    t_index = np.where(cls_pro > threshold)

    # find nothing
    if t_index[0].size == 0:
        return np.array([])
    # regression offsets for the cells above threshold
    bbox_pred = bbox_pred[t_index[0], t_index[1], :]
    bbox_pred = np.reshape(bbox_pred, (-1, 4))
    score = cls_pro[t_index[0], t_index[1]]
    score = np.reshape(score, (-1, 1))

    x1Arr = np.round((stride * t_index[1]) / scale)
    x1Arr = np.reshape(x1Arr, (-1, 1))
    y1Arr = np.round((stride * t_index[0]) / scale)
    y1Arr = np.reshape(y1Arr, (-1, 1))
    x2Arr = np.round((stride * t_index[1] + cellsize) / scale)
    x2Arr = np.reshape(x2Arr, (-1, 1))
    y2Arr = np.round((stride * t_index[0] + cellsize) / scale)
    y2Arr = np.reshape(y2Arr, (-1, 1))

    bboxes = np.concatenate([x1Arr, y1Arr, x2Arr, y2Arr, score, bbox_pred], -1)


    return bboxes

# calibrate boxes with the predicted regression offsets
def calibrate_box(bboxes,offsets):
    """
    :param bboxes: [n,5] boxes with scores
    :param offsets: [n,4] regression offsets, normalized by box width/height
    :return: [n,5] calibrated boxes
    """

    x1,y1,x2,y2 = [bboxes[:,i] for i in range(4)]
    w = x2 - x1 + 1.0
    h = y2 - y1 + 1.0

    w = np.expand_dims(w,1)
    h = np.expand_dims(h,1)

    translation = np.hstack([w,h,w,h]) * offsets
    # work on a copy so the caller's uncalibrated boxes are not mutated
    bboxes = bboxes.copy()
    bboxes[:,0:4] = bboxes[:,0:4] + translation

    return bboxes

# non-maximum suppression
def nms(dets, thresh):
    '''Drop boxes that overlap a higher-scoring box too much.'''
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # sort scores in descending order
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h

        ovr = inter / (areas[i] + areas[order[1:]] - inter + 1e-10)

        # keep indices whose IoU is below the threshold; order[0] was removed
        # for the comparison, so inds + 1 maps back into order
        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]

    return keep

# convert rectangles to squares
def convert_to_square(bboxes):
    """
    Expand each box to a square around the same center.
    :param bboxes: [n,5]
    :return: [n,5]
    """
    # copy (rather than zeros_like) so the score column is preserved
    square_bboxes = bboxes.copy()
    x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]

    h = y2 - y1 + 1.0
    w = x2 - x1 + 1.0
    max_side = np.maximum(h, w)
    square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5
    square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5
    square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0
    square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0
    return square_bboxes

# handle boxes that extend beyond the image
def pad(bboxes, w, h):
    '''Clip boxes that extend beyond the image.
    Args:
      bboxes: face boxes
      w, h: image width and height
    Returns:
      dy, dx   : top-left corner of the clipped region, relative to the box
      edy, edx : bottom-right corner of the clipped region, relative to the box
      y, x     : top-left corner of the clipped box in the original image
      ey, ex   : bottom-right corner of the clipped box in the original image
    '''
    tw, th = bboxes[:, 2] - bboxes[:, 0] + 1, bboxes[:, 3] - bboxes[:, 1] + 1
    n_box = bboxes.shape[0]

    dx, dy = np.zeros((n_box,)), np.zeros((n_box,))
    edx, edy = tw.copy() - 1, th.copy() - 1
    # top-left and bottom-right corners of each box
    x, y, ex, ey = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
    # clip boxes past the right/bottom edge: ex, ey are pulled back to the image
    # border, and edx, edy become the clipped corners relative to each box

    tmp_index = np.where(ex > w - 1)
    edx[tmp_index] = tw[tmp_index]  - 1 - (ex[tmp_index] - w + 1)
    ex[tmp_index] = w - 1

    tmp_index = np.where(ey > h - 1)
    edy[tmp_index] = th[tmp_index] - 1 - (ey[tmp_index] - h + 1)
    ey[tmp_index] = h - 1

    # clip boxes past the left/top edge: x, y are pulled up to 0, and dx, dy
    # become the clipped corners relative to each box
    tmp_index = np.where(x < 0)
    dx[tmp_index] = 0 - x[tmp_index]
    x[tmp_index] = 0

    tmp_index = np.where(y < 0)
    dy[tmp_index] = 0 - y[tmp_index]
    y[tmp_index] = 0

    return_list = [dy, edy, dx, edx, y, ey, x, ex, tw, th]
    return_list = [item.astype(np.int32) for item in return_list]

    return return_list



def detect_pent(image):
    """

    :param image: 要预测的图片
    :return: 校准后的预测方框
    """
    num_thresh = 0.7
    scale_factor = 0.709
    P_thresh = 0.5
    model = Pnet()
    model.load_weights("./Weights/pnet_wight/pnet_30.ckpt")



    net_size = 12
    current_scale = float(net_size) / min_face_size
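    # start the pyramid at the scale where a min_face_size face fills P-Net's
    # 12x12 input, then shrink by scale_factor on every iteration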

    im_resized = processed_img(image,current_scale)
    # print("im_resized",im_resized.shape)
    current_h,current_w,_ = im_resized.shape


    all_boxes = list()

    while min(current_h,current_w) > net_size:
        # P-Net expects input of shape [b, h, w, 3], so add a batch dimension
        img_resized = tf.expand_dims(im_resized, axis=0)
        img_resized = tf.cast(img_resized, tf.float32)

        cls_prob, bbox_pred = model.predict(img_resized)

        cls_prob = cls_prob[0]
        bbox_pred = bbox_pred[0]

        bboxes = generate_bounding_box(cls_prob[:,:,1],bbox_pred,current_scale,0.6)
        # print("bboxes",bboxes)
        current_scale *= scale_factor

        im_resized = processed_img(image,current_scale)
        current_h, current_w, _ = im_resized.shape

        if bboxes.size == 0:
            continue


        keep = nms(bboxes[:, :5], 0.5)
        bboxes = bboxes[keep]
        all_boxes.append(bboxes)

    if len(all_boxes) == 0:
        return None
    all_boxes = np.vstack(all_boxes)
    keep = nms(all_boxes[:, :5], 0.7)
    all_boxes = all_boxes[keep]

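    # apply P-Net's own regression offsets (columns 5:9) to calibrate the boxes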
    boxes = np.copy(all_boxes[:, :5])
    bbw = all_boxes[:, 2] - all_boxes[:, 0] + 1
    bbh = all_boxes[:, 3] - all_boxes[:, 1] + 1
    x1Arr = all_boxes[:, 0] + all_boxes[:, 5] * bbw
    y1Arr = all_boxes[:, 1] + all_boxes[:, 6] * bbh
    x2Arr = all_boxes[:, 2] + all_boxes[:, 7] * bbw
    y2Arr = all_boxes[:, 3] + all_boxes[:, 8] * bbh
    scoreArr = all_boxes[:, 4]


    boxes_c = np.concatenate([x1Arr.reshape(-1, 1),
                              y1Arr.reshape(-1, 1),
                              x2Arr.reshape(-1, 1),
                              y2Arr.reshape(-1, 1),
                              scoreArr.reshape(-1, 1)],
                             axis=-1)
    return boxes,boxes_c


def detect_Rnet(img,dets):
    '''Filter the P-Net boxes with R-Net.
    Args:
      img: input image
      dets: boxes selected by P-Net, absolute coordinates on the original image
    Returns:
      boxes in absolute coordinates
    '''

    model = Rnet()
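    # note: this path should point at the weights saved by train_Rnet.py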
    model.load_weights("rnet.h5")
    h,w,_ = img.shape
    # expand each P-Net box to the square containing it, to avoid losing information

    dets = convert_to_square(dets)
    # print("dets",dets)
    dets[:,0:4] = np.round(dets[:,0:4])

    [dy, edy, dx, edx, y, ey, x, ex, dw, dh] = pad(dets,w,h)
    # print("dy",dw)
    delete_size = np.ones_like(dw) * 20
    ones = np.ones_like(dw)
    zeros = np.zeros_like(dw)
    num_boxes = np.sum(np.where((np.minimum(dw, dh) >= delete_size), ones, zeros))
    cropped_imgs = np.zeros((num_boxes, 24, 24, 3), dtype=np.float32)

    for i in range(num_boxes):


        # 将pnet生成的box相对与原图进行裁剪,超出部分用0补
        # 将pnet生成的box相对与原图进行裁剪,超出部分用0补
        if dh[i] < 20 or dw[i] < 20:
            continue
        tmp = np.zeros((dh[i],dw[i],3),dtype=np.uint8)
        tmp[dy[i]:edy[i]+1,dx[i]:edx[i]+1,:] = img[y[i]:ey[i]+1,x[i]:ex[i]+1,:]

        cropped_imgs[i,:,:,:] = (cv2.resize(tmp,(24,24)) - 127.5) / 128

    cls_scores, reg = model.predict(cropped_imgs)



    cls_scores = cls_scores[:, 1]

    keep_inds = np.where(cls_scores > 0.7)[0]

    if len(keep_inds) > 0:
        boxes = dets[keep_inds]
        boxes[:, 4] = cls_scores[keep_inds]
        reg = reg[keep_inds]
    else:
        return None, None

    keep = nms(boxes, 0.6)
    boxes = boxes[keep]
    # calibrate the P-Net crops with R-Net's regression values, giving face
    # boxes in absolute coordinates on the original image
    boxes_c = calibrate_box(boxes, reg[keep])
    return boxes, boxes_c

def detect_Onet(img,dets):

    """
     将onet的选框继续筛选基本和rnet差不多但多返回了landmark
    :param img:
    :param dets: rnet_
    :return:
    """
    model = Onet()
    model.load_weights("./Weights/Onet_wight/onet_30.ckpt")
    h,w,_ = img.shape
    dets = convert_to_square(dets)
    dets[:,0:4] = np.round(dets[:,0:4])
    [dy, edy, dx, edx, y, ey, x, ex, dw, dh] = pad(dets,w,h)

    n_boxes = dets.shape[0]

    cropped_imgs = np.zeros((n_boxes,48,48,3),dtype=np.float32)
    for i in range(n_boxes):

        tmp = np.zeros((dh[i], dw[i], 3), dtype=np.uint8)
        tmp[dy[i]:edy[i] + 1, dx[i]:edx[i] + 1, :] = img[y[i]:ey[i] + 1, x[i]:ex[i] + 1, :]
        cropped_imgs[i, :, :, :] = (cv2.resize(tmp, (48, 48)) - 127.5) / 128


    cls_scores, reg, landmark = model.predict(cropped_imgs)  # landmarks not returned here

    cls_scores = cls_scores[:,1]

    keep_inds = np.where(cls_scores > 0.7)[0]

    if len(keep_inds) > 0:
        boxes = dets[keep_inds]
        boxes[:,4] = cls_scores[keep_inds]
        reg = reg[keep_inds]

    else:
        return None,None



    boxes_c = calibrate_box(boxes,reg)
    boxes = boxes[nms(boxes,0.6)]
    keep = nms(boxes_c,0.6)

    boxes_c = boxes_c[keep]

    return boxes,boxes_c

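As a quick check, here is a minimal usage sketch for detect_pent(); the image path is a placeholder, and it assumes detect.py plus the P_net weights from the previous post are in place:

import cv2
from detect import detect_pent

img = cv2.imread("test.jpg")                  # placeholder image path
rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)    # same preprocessing as the gen script
result = detect_pent(rgb)
if result is not None:
    boxes, boxes_c = result
    for x1, y1, x2, y2, score in boxes_c:
        cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
cv2.imwrite("test_out.jpg", img)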

The following script generates the Rnet dataset, and with one change the Onet dataset as well.

gen_Rnet_Onet_data.py

import numpy as np
import cv2
import os,sys

from detect import detect_pent,detect_Rnet

RorO = 24  # 24 for the Rnet dataset; change to 48 to generate the Onet dataset

stdsize = RorO


# im_dir = "samples"
pos_save_dir = str(stdsize) + "/positive"
part_save_dir = str(stdsize) + "/part"
neg_save_dir = str(stdsize) + '/negative'
save_dir = str(RorO)



def IoU(box, boxes):
    """Compute IoU between detect box and gt boxes

    Parameters:
    -----------
    box: numpy array , shape (5, ): x1, y1, x2, y2, score
        input box
    boxes: numpy array, shape (n, 4): x1, y1, x2, y2
        input ground truth boxes

    Returns:
    --------
    ovr: numpy.array, shape (n, )
        IoU
    """
    # box = (x1, y1, x2, y2)
    box_area = (box[2] - box[0] + 1) * (box[3] - box[1] + 1)
    area = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)

    # obtain the corners of the intersection between crop_box and each gt box
    xx1 = np.maximum(box[0], boxes[:, 0])
    yy1 = np.maximum(box[1], boxes[:, 1])
    xx2 = np.minimum(box[2], boxes[:, 2])
    yy2 = np.minimum(box[3], boxes[:, 3])

    # compute the width and height of the bounding box
    w = np.maximum(0, xx2 - xx1 + 1)
    h = np.maximum(0, yy2 - yy1 + 1)

    inter = w * h
    ovr = inter / (box_area + area - inter)
    return ovr





def mkr(dr):
    if not os.path.exists(dr):
        os.mkdir(dr)

mkr(save_dir)
mkr(pos_save_dir)
mkr(part_save_dir)
mkr(neg_save_dir)

# create txt files to hold the annotations for the positive, negative, and part samples
f1 = open(os.path.join(save_dir, 'pos_' + str(stdsize) + '.txt'), 'w')
f2 = open(os.path.join(save_dir, 'neg_' + str(stdsize) + '.txt'), 'w')
f3 = open(os.path.join(save_dir, 'part_' + str(stdsize) + '.txt'), 'w')



annotations = np.load("labels8.npy")
imgs=np.load("imgs8.npy")

num = len(annotations)
print("%d pics in total" % num)
p_idx = 0 # positive
n_idx = 0 # negative
d_idx = 0 # dont care
idx = 0
box_idx = 0

size_i = RorO

def convert_to_square(bbox):
    """Convert bbox to square

    Parameters:
    ----------
    bbox: numpy array , shape n x 5
        input bbox

    Returns:
    -------
    square bbox
    """
    square_bbox = bbox.copy()

    h = bbox[:, 3] - bbox[:, 1] + 1
    w = bbox[:, 2] - bbox[:, 0] + 1
    max_side = np.maximum(h,w)
    square_bbox[:, 0] = bbox[:, 0] + w*0.5 - max_side*0.5
    square_bbox[:, 1] = bbox[:, 1] + h*0.5 - max_side*0.5
    square_bbox[:, 2] = square_bbox[:, 0] + max_side - 1
    square_bbox[:, 3] = square_bbox[:, 1] + max_side - 1
    return square_bbox


# 100 images are enough here; use range(len(annotations)) for the full set
for i in range(100):
    boxes = annotations[i][0:8].reshape(-1, 2)
    ix=boxes[:,0].min()
    iy=boxes[:,1].min()
    ax=boxes[:,0].max()
    ay=boxes[:,1].max()
    boxes=np.array([[ix,iy,ax,ay]])

    image = imgs[i].copy()
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    if size_i == 24:
        result = detect_pent(image)                   # Pnet only for Rnet data
    else:
        result = detect_pent(image)                   # Pnet
        if result is not None:
            result = detect_Rnet(image, result[1])    # Rnet, fed Pnet's boxes_c
    # detect_pent returns None and detect_Rnet returns (None, None) when nothing
    # is found, so guard before unpacking
    if result is None or result[0] is None:
        continue
    bbos, pre_bboxes = result

    print("number of boxes:", len(pre_bboxes))


    pre_bboxes = np.array(pre_bboxes)
    dets = convert_to_square(pre_bboxes)

    dets[:, 0:4] = np.round(dets[:, 0:4])

    img = imgs[i]
    idx += 1

    height,width,channel = img.shape

    neg_num = 0
    for box in dets:

        x_left, y_top, x_right, y_bottom = box[0:4].astype(int)
        # box width/height, named so they don't shadow the image's width/height
        box_w = x_right - x_left + 1
        box_h = y_bottom - y_top + 1
        if box_w < 20 or x_left < 0 or y_top < 0 or x_right > img.shape[1] - 1 or y_bottom > img.shape[0] - 1:
            continue


        Iou = IoU(box,boxes)

        cropped_im = img[y_top:y_bottom + 1, x_left:x_right + 1, :]

        resized_im = cv2.resize(cropped_im, (stdsize, stdsize),
                                interpolation=cv2.INTER_LINEAR)

        # if np.max(Iou) < 0.2 and n_idx < 3.0 * p_idx + 1:
        if np.max(Iou) < 0.3 and neg_num < 60:
            save_file = os.path.join(neg_save_dir,"%s.jpg"%n_idx)
            f2.write(str(stdsize)+"/negative/%s"% n_idx + " 0\n")
            cv2.imwrite(save_file,resized_im)
            n_idx += 1
            neg_num += 1

        else:
            idx_Iou = np.argmax(Iou)
            assigned_gt = boxes[idx_Iou]
            x1,y1,x2,y2 = assigned_gt

            offset_x1 = (x1 - x_left) / float(box_w)
            offset_y1 = (y1 - y_top) / float(box_h)
            offset_x2 = (x2 - x_right) / float(box_w)
            offset_y2 = (y2 - y_bottom) / float(box_h)
            if np.max(Iou) >= 0.65:
                save_file = os.path.join(pos_save_dir, "%s.jpg" % p_idx)
                f1.write(str(stdsize)+"/positive/%s"%p_idx + ' 1 %.2f %.2f %.2f %.2f\n' % (
                    offset_x1, offset_y1, offset_x2, offset_y2))
                cv2.imwrite(save_file, resized_im)
                p_idx += 1

            elif np.max(Iou) >= 0.4 and d_idx < 1.0 * p_idx + 1:
                save_file = os.path.join(part_save_dir, "%s.jpg" % d_idx)
                f3.write(str(stdsize)+"/part/%s"%d_idx + ' -1 %.2f %.2f %.2f %.2f\n' % (
                    offset_x1, offset_y1, offset_x2, offset_y2))
                cv2.imwrite(save_file, resized_im)
                d_idx += 1

    print("%s images done, pos: %s part: %s neg: %s" % (idx, p_idx, d_idx, n_idx))


f1.close()
f2.close()
f3.close()
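
A note on the offsets written above: they are normalized by the crop's width and height, which is exactly what calibrate_box() in detect.py multiplies back. A toy round trip with made-up numbers shows the encoding is invertible:

import numpy as np
from detect import calibrate_box

# one 50x50 crop plus the offsets a perfect network would predict for it
bbox = np.array([[100., 100., 149., 149., 0.9]])   # x1, y1, x2, y2, score
offsets = np.array([[0.2, 0.1, 0.12, 0.06]])       # (gt - crop) / crop size
print(calibrate_box(bbox, offsets))
# [[110. 105. 155. 152. 0.9]] -- the ground-truth box is recovered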

Run writelabel.py and gen_tfrecord.py to generate the tfrecord file.
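
writelabel.py and gen_tfrecord.py are from the P_net post. For orientation, the parsing side in red_tf() typically looks like the sketch below; the feature keys are the usual MTCNN convention and are an assumption here, so match them to what gen_tfrecord.py actually writes:

import tensorflow as tf

# a minimal sketch of parsing one Example; the feature keys are assumptions
def parse_example(serialized, size=24):
    features = tf.io.parse_single_example(serialized, {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/label': tf.io.FixedLenFeature([], tf.int64),
        'image/roi': tf.io.FixedLenFeature([4], tf.float32),
    })
    img = tf.io.decode_raw(features['image/encoded'], tf.uint8)
    img = tf.reshape(img, [size, size, 3])
    img = (tf.cast(img, tf.float32) - 127.5) / 128.0   # same normalization as inference
    return img, features['image/label'], features['image/roi']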

Run the script below to train Rnet.

train_Rnet.py

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import metrics
from read_tfrecord import *
from MTCNN_ import Rnet,cls_ohem,cal_accuracy,bbox_ohem
from tqdm import tqdm




data_path = "24/train_RNet_landmark.tfrecord_shuffle"



# data-loading helper (the name is left over from an earlier tutorial)
def load_pokemon(mode='train'):
    """Load the R_net tfrecord and return one split of it.
    :param mode: which split to load: 'train', 'val', or 'test'
    """
    size = 24
    images, labels, boxes = red_tf(data_path, size)

    # note: 'train' uses all samples here; 'val' takes 70%-85%; 'test' the rest
    if mode == 'train':
        images = images[:int(len(images))]
        labels = labels[:int(len(labels))]
        boxes  = boxes[:int(len(boxes))]
    elif mode == 'val':
        images = images[int(0.7 * len(images)):int(0.85 * len(images))]
        labels = labels[int(0.7 * len(labels)):int(0.85 * len(labels))]
        boxes = boxes[int(0.7 * len(boxes)):int(0.85 * len(boxes))]
    else:
        images = images[int(0.85 * len(images)):]
        labels = labels[int(0.85 * len(labels)):]
        boxes = boxes[int(0.85 * len(boxes)):]
    ima = tf.data.Dataset.from_tensor_slices(images)
    lab = tf.data.Dataset.from_tensor_slices(labels)
    roi = tf.data.Dataset.from_tensor_slices(boxes)

    train_data = tf.data.Dataset.zip((ima, lab, roi)).shuffle(1000).batch(16)
    train_data = list(train_data.as_numpy_iterator())
    return train_data




# color augmentation: random contrast, brightness, hue, and saturation
def image_color_distort(inputs):
    inputs = tf.image.random_contrast(inputs, lower=0.5, upper=1.5)
    inputs = tf.image.random_brightness(inputs, max_delta=0.2)
    inputs = tf.image.random_hue(inputs,max_delta= 0.2)
    inputs = tf.image.random_saturation(inputs,lower = 0.5, upper= 1.5)
    return inputs


def train(epochs):
    model = Rnet()
    # model.load_weights("rnet.h5")  # uncomment to resume from saved weights

    optimizer = keras.optimizers.Adam(learning_rate=1e-3)
    for epoch in tqdm(range(epochs)):

        for i,(img,lab,boxes) in enumerate(load_pokemon("train")):


            img = image_color_distort(img)
            # record the forward pass on a GradientTape to compute gradients
            with tf.GradientTape() as tape:
                cls_prob, bbox_pred = model(img)
                cls_loss = cls_ohem(cls_prob, lab)
                bbox_loss = bbox_ohem(bbox_pred, boxes,lab)
                # landmark_loss = landmark_loss_fn(landmark_pred, landmark_batch, label_batch)
                # accuracy = cal_accuracy(cls_prob, label_batch)


                total_loss_value = cls_loss + 0.5 * bbox_loss
                grads = tape.gradient(total_loss_value, model.trainable_variables)
                optimizer.apply_gradients(zip(grads, model.trainable_variables))
            if i % 200 == 0:
                print('Training loss (for one batch) at step %s: %s' % (i, float(total_loss_value)))
                print('Seen so far: %s samples' % ((i + 1) * 16))  # batch size is 16


        # validation: forward pass only, with no augmentation and no gradient updates
        for i, (v_img, v_lab, v_boxes) in enumerate(load_pokemon("val")):
            cls_prob, bbox_pred = model(v_img)
            cls_loss = cls_ohem(cls_prob, v_lab)
            bbox_loss = bbox_ohem(bbox_pred, v_boxes, v_lab)
            total_loss_value = cls_loss + 0.5 * bbox_loss
            if i % 200 == 0:
                print('val loss (for one batch) at step %s: %s' % (i, float(total_loss_value)))
                print('Seen so far: %s samples' % ((i + 1) * 16))
    model.save_weights('./Weights/Rnet_wight/rnet_30.ckpt')
train(30)
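
cls_ohem and bbox_ohem come from MTCNN_.py in the P_net post. For reference, a minimal sketch of an online-hard-example-mining classification loss, as commonly written for MTCNN, is below; it illustrates the idea and is not a copy of MTCNN_.py:

import tensorflow as tf

# illustrative OHEM classification loss (an assumption, not the author's code):
# part samples (label -1) are masked out of the cls loss, and only the hardest
# 70% of the remaining samples contribute to the gradient
def cls_ohem_sketch(cls_prob, label, keep_ratio=0.7):
    label = tf.reshape(tf.cast(label, tf.int32), [-1])
    valid = tf.cast(label >= 0, tf.float32)              # mask out part samples
    safe_label = tf.maximum(label, 0)                    # avoid indexing with -1
    idx = tf.stack([tf.range(tf.shape(label)[0]), safe_label], axis=1)
    prob = tf.gather_nd(cls_prob, idx)                   # p(correct class) per sample
    loss = -tf.math.log(prob + 1e-10) * valid
    keep_num = tf.cast(tf.reduce_sum(valid) * keep_ratio, tf.int32)
    hard_loss, _ = tf.math.top_k(loss, k=keep_num)       # keep the hardest samples
    return tf.reduce_mean(hard_loss)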
