Tensorflow2.0---YOLO V4-tiny网络原理及代码解析(二)- 数据的生成

40 篇文章 2 订阅
19 篇文章 43 订阅

Tensorflow2.0—YOLO V4-tiny网络原理及代码解析(二)- 数据的生成

Tensorflow2.0—YOLO V4-tiny网络原理及代码解析(一)- 特征提取网络中已经把YOLO V4的特征提取网络给讲完了,这篇blog来讲讲数据的生成(其实,v4与v3的数据生成的方式几乎相同)。
首先,来看下真实框编码的主函数位置在哪?
它是在train.py中,这就是它的主函数:data_generator。

def data_generator(annotation_lines, batch_size, input_shape, anchors, num_classes, mosaic=False, random=True):
    '''
    参数信息:
    annotation_lines: ['C:\\Users\\user\\Desktop\\yolov4-tiny-tf2-master/VOCdevkit/VOC2007/JPEGImages/000044.jpg 1,1,370,330,8 99,101,312,213,7\n',
    'C:\\Users\\user\\Desktop\\yolov4-tiny-tf2-master/VOCdevkit/VOC2007/JPEGImages/000039.jpg 156,89,344,279,19\n']
     batch_size: 2
     input_shape:(416, 416)
     anchors:[[ 10.  14.]
            [ 23.  27.]
            [ 37.  58.]
            [ 81.  82.]
            [135. 169.]
            [344. 319.]]
     num_classes:20
    '''

可以在代码中看到,annotation_lines是包含所有训练数据的列表,其中每一个元素就是一张训练数据图片和其对应打过标签的labels。这里为了debug,设置所有训练数据数量为2。

	n = len(annotation_lines)  #总训练数据数量,这里为2
    i = 0
    flag = True
    while True:
        image_data = []
        box_data = []
        for b in range(batch_size):
            if i==0:
                np.random.shuffle(annotation_lines)
            if mosaic:
                if flag and (i+4) < n:
                    image, box = get_random_data_with_Mosaic(annotation_lines[i:i+4], input_shape)
                    i = (i+4) % n
                else:
                    image, box = get_random_data(annotation_lines[i], input_shape, random=random)
                    i = (i+1) % n
                flag = bool(1-flag)
            else:
                image, box = get_random_data(annotation_lines[i], input_shape, random=random)
                i = (i+1) % n
            image_data.append(image)
            box_data.append(box)
        image_data = np.array(image_data)
        box_data = np.array(box_data)

上面这一段代码:首先是创建用于保存图片信息和其对应的gt boxes信息的列表,然后对每个批次中图片数量进行遍历。首先,如果是第一次运行,先随机将训练数据进行打乱。由于该代码中没有使用mosaic数据增强,直接进行常规的数据增强get_random_data(数据增强代码,有机会专门写个blog来讲讲数据增强)。
image_data
在这里插入图片描述
box_data【shape=(2,100,5),表示每个批次2个数据,每条数据有100个gt boxes,0-4表示xyxy,5表示数据哪个类别】:
在这里插入图片描述
然后将box_data喂进preprocess_true_boxes函数中。


下面,来看下preprocess_true_boxes函数:
一、首先,将xyxy格式的gt框转换为xywh格式,并相对于416进行归一化。

#-----------------------------------------------------------#
    #   通过计算获得真实框的中心和宽高
    #   中心点(m,n,2) 宽高(m,n,2)·
    #-----------------------------------------------------------#
    boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2
    boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]
    #-----------------------------------------------------------#
    #   将真实框归一化到小数形式
    #-----------------------------------------------------------#
    true_boxes[..., 0:2] = boxes_xy/input_shape[::-1]
    true_boxes[..., 2:4] = boxes_wh/input_shape[::-1]

得到的true_boxes为:
在这里插入图片描述
二、创建y_true

m = true_boxes.shape[0]
    grid_shapes = [input_shape//{0:32, 1:16, 2:8}[l] for l in range(num_layers)]
    #-----------------------------------------------------------#
    #   y_true的格式为(m,13,13,3,25)(m,26,26,3,25)
    #-----------------------------------------------------------#
    y_true = [np.zeros((m,grid_shapes[l][0],grid_shapes[l][1],len(anchor_mask[l]),5+num_classes),
        dtype='float32') for l in range(num_layers)]

三、编码

#-----------------------------------------------------------#
    #   [6,2] -> [1,6,2]
    #-----------------------------------------------------------#
    anchors = np.expand_dims(anchors, 0)
    anchor_maxes = anchors / 2.
    anchor_mins = -anchor_maxes

    #-----------------------------------------------------------#
    #   长宽要大于0才有效
    #-----------------------------------------------------------#
    valid_mask = boxes_wh[..., 0]>0

    for b in range(m):
        # 对每一张图进行处理
        wh = boxes_wh[b, valid_mask[b]]
        if len(wh)==0: continue
        #-----------------------------------------------------------#
        #   [n,2] -> [n,1,2]
        #-----------------------------------------------------------#
        wh = np.expand_dims(wh, -2)
        box_maxes = wh / 2.
        box_mins = -box_maxes

        #-----------------------------------------------------------#
        #   计算所有真实框和先验框的交并比
        #   intersect_area  [n,6]
        #   box_area        [n,1]
        #   anchor_area     [1,6]
        #   iou             [n,6]
        #-----------------------------------------------------------#
        intersect_mins = np.maximum(box_mins, anchor_mins)
        intersect_maxes = np.minimum(box_maxes, anchor_maxes)
        intersect_wh = np.maximum(intersect_maxes - intersect_mins, 0.)
        intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]

        box_area = wh[..., 0] * wh[..., 1]
        anchor_area = anchors[..., 0] * anchors[..., 1]

        iou = intersect_area / (box_area + anchor_area - intersect_area)
        #-----------------------------------------------------------#
        #   维度是[n,] 感谢 消尽不死鸟 的提醒
        #-----------------------------------------------------------#
        best_anchor = np.argmax(iou, axis=-1)

        for t, n in enumerate(best_anchor):
            #-----------------------------------------------------------#
            #   找到每个真实框所属的特征层
            #-----------------------------------------------------------#
            for l in range(num_layers):
                if n in anchor_mask[l]:
                    #-----------------------------------------------------------#
                    #   floor用于向下取整,找到真实框所属的特征层对应的x、y轴坐标
                    #-----------------------------------------------------------#
                    i = np.floor(true_boxes[b,t,0] * grid_shapes[l][1]).astype('int32')
                    j = np.floor(true_boxes[b,t,1] * grid_shapes[l][0]).astype('int32')
                    #-----------------------------------------------------------#
                    #   k指的的当前这个特征点的第k个先验框
                    #-----------------------------------------------------------#
                    k = anchor_mask[l].index(n)
                    #-----------------------------------------------------------#
                    #   c指的是当前这个真实框的种类
                    #-----------------------------------------------------------#
                    c = true_boxes[b, t, 4].astype('int32')
                    #-----------------------------------------------------------#
                    #   y_true的shape为(m,13,13,3,85)(m,26,26,3,85)(m,52,52,3,85)
                    #   最后的85可以拆分成4+1+80,4代表的是框的中心与宽高、
                    #   1代表的是置信度、80代表的是种类
                    #-----------------------------------------------------------#
                    y_true[l][b, j, i, k, 0:4] = true_boxes[b, t, 0:4]
                    y_true[l][b, j, i, k, 4] = 1
                    y_true[l][b, j, i, k, 5+c] = 1

这么长段代码的作用:其实就是先对每一个gt框分别与6个anchor box进行iou求解,求解出最大的iou值,并找到其所在位置。然后在y_true中赋值。

y_true[l][b, j, i, k, 0:4] = true_boxes[b, t, 0:4]
y_true[l][b, j, i, k, 4] = 1
y_true[l][b, j, i, k, 5+c] = 1

最后三行代码的含义:
1.把归一化后的xywh进行位置赋值。
2.将第5个位置是否是物体的概率设置为1。
3.赋值属于到底哪个物体。

在代码中,会出现很多在列表前面加一个*,这是什么意思?
答:表示将列表中的元素进行分开。
在这里插入图片描述

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

进我的收藏吃灰吧~~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值