个人总结
/keras-yolo3-master/yolo3/model.py中的preprocess_true_boxes函数
作用:当读入GT时,传入模型前对GT进行预处理(将gt信息对应到网络的输出上)。
def preprocess_true_boxes(true_boxes, input_shape, anchors, num_classes):
'''Preprocess true boxes to training input format
Parameters
----------
true_boxes: array, shape=(m, T, 5)
Absolute x_min, y_min, x_max, y_max, class_id relative to input_shape
其中m为batch size,T为设定的每张图片的最多样本数,这里为20。,5为具体的GT信息,如上。
input_shape: array-like, hw, multiples of 32,输入图片的size,必为32的倍数。
anchors: array, shape=(N, 2), wh。这里为anchors的size为(9,2)。
num_classes: integer
Returns
-------
y_true: list of array, shape like yolo_outputs, xywh are reletive value
'''
# 判断class id是否超出总的类别数,比如一共只有5类,最大类别代号为4,当出现大于4的类别代号时就报错。
assert (true_boxes[..., 4]<num_classes).all(), 'class id must be less than num_classes'
# 这里为正常版本len(anchors)=9,thiny版本可能有不一样?
num_layers = len(anchors)//3 # default setting
anchor_mask = [[6,7,8], [3,4,5], [0,1,2]] if num_layers==3 else [[3,4,5], [1,2,3]]
true_boxes = np.array(true_boxes, dtype='float32') # (32,20,5)
input_shape = np.array(input_shape, dtype='int32')
boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2# 中心点坐标
boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]# 宽高
true_boxes[..., 0:2] = boxes_xy/input_shape[::-1]# 将x,y,w,h换算为相对原图片大小的一个小数。
true_boxes[..., 2:4] = boxes_wh/input_shape[::-1]
m = true_boxes.shape[0]#batch_size:32
grid_shapes = [input_shape//{0:32, 1:16, 2:8}[l] for l in range(num_layers)]# [array([13, 13]), array([26, 26]), array([52, 52])]表示三个pyramid上特征图的大小。
"""
y_true是一个list
>>> y_true[0].shape
(32, 13, 13, 3, 9)
>>> y_true[1].shape
(32, 26, 26, 3, 9)
>>> y_true[2].shape
(32, 52, 52, 3, 9)
>>> y_true[3].shape
"""
y_true = [np.zeros((m,grid_shapes[l][0],grid_shapes[l][1],len(anchor_mask[l]),5+num_classes),
dtype='float32') for l in range(num_layers)]
# Expand dim to apply broadcasting.
anchors = np.expand_dims(anchors, 0)
anchor_maxes = anchors / 2.
anchor_mins = -anchor_maxes
valid_mask = boxes_wh[..., 0]>0
for b in range(m):# 对batch size循环
# Discard zero rows.
wh = boxes_wh[b, valid_mask[b]]# 读取GT的WH
if len(wh)==0: continue
# Expand dim to apply broadcasting.
wh = np.expand_dims(wh, -2)# Expand dim to apply broadcasting.
# 思想为,将anchors和boxs都以原点为中心,放在一起,计算两者的重叠区域。
box_maxes = wh / 2.
box_mins = -box_maxes
intersect_mins = np.maximum(box_mins, anchor_mins)# 计算重叠区域左上角,通过最大值
intersect_maxes = np.minimum(box_maxes, anchor_maxes)# 计算重叠区域右下角,通过最小值
intersect_wh = np.maximum(intersect_maxes - intersect_mins, 0.)# 计算重叠区域的wh
intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]# 计算重叠区域的面积
box_area = wh[..., 0] * wh[..., 1]# bax的面积
anchor_area = anchors[..., 0] * anchors[..., 1]# anchors的面积
# intersect_area的size:(3,9)三个gt,9个anchors。
# box_area的size:(3,1)
# anchor_area的size:(1,9)
# iou的size:(3,9)。表示三个gt和9个anchors分别的重叠比例。
iou = intersect_area / (box_area + anchor_area - intersect_area)# 计算iou
# Find best anchor for each true box,得到一个长度为3的np数组,分别代表与3个gt iou最大的anchor的索引。
best_anchor = np.argmax(iou, axis=-1)
for t, n in enumerate(best_anchor):# 对gt个数进行循环
for l in range(num_layers):# 对三个pyramid进行循环。
if n in anchor_mask[l]:# 找到对应的层级
# floor向下取整
# true_boxes中:b代表对batch循环,不用管;t代表对gt进行循环,0代表该gt的中心点x坐标(已经相对原图归一化到0~1)。grid_shapes代表各个pyramid上特征图的大小,l代表层数,1代表该层的宽。相乘代表将gt对应到该层的特征图上。
i = np.floor(true_boxes[b,t,0]*grid_shapes[l][1]).astype('int32')
# 和上一行大体一样,true_boxes中的1代表中心点y坐标,grid_shapes中0代表高度。
j = np.floor(true_boxes[b,t,1]*grid_shapes[l][0]).astype('int32')
k = anchor_mask[l].index(n)# 每层都会由三个anchors,共三层9个anchor,k表示该gt对应的是该层的第几个anchor。
c = true_boxes[b,t, 4].astype('int32')# c表示对应的类别标签
# y_true是一个list,len(y_true)=3,size分别为(32,13,13,3,9)(32,26,26,3,9)(32,52,52,3,9)。这里将gt按照之前的计算安放到对应的位置上。l表示在pyramid的层数;b表示batch size迭代;j,i表示在特征图上的位置(j表示y坐标,行号;i表示x坐标,列号)。最后位置是4个位置信息,1个置信度信息,4个类别信息。
y_true[l][b, j, i, k, 0:4] = true_boxes[b,t, 0:4]
y_true[l][b, j, i, k, 4] = 1
y_true[l][b, j, i, k, 5+c] = 1
return y_true
网络的输出包含三部分(分别在三个层上采样的结果),每个部分输出一个np,三个np组成一个list。每部分np的shape为[batch_size,h,w,3,5+num_classes],其中h,w为对应特征图的高和宽;3代表特征图上的每个位置产生三个anchor;5表示四个xywh位置信息加一个置信度信息;5+num_classes表示每个特征图上每个位置每个anchor都对应一批位置置信度类别信息。
数据处理
数据的输入实例如下:
/home/tf/keras-yolo3-master/data/VOCdevkit/VOC2007/JPEGImages/2017-12-10112013.jpg 501,391,520,409,0 598,415,615,432,0
/home/tf/keras-yolo3-master/data/VOCdevkit/VOC2007/JPEGImages/2017-12-10112407.jpg 732,309,800,373,2