Faster R-CNN: a detailed explanation of the functions in the code

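Table of Contents

1. Data preparation

2. RPN training

2.1 Network model

2.2 Training the network

2.3 RPN loss

3. rpn_to_roi

3.1 Bounding Box Regression

3.2 NMS

4. calc_iou

5. Feeding a fixed 4 ROIs into the training network

6. Training the classifier network (model_classifier)

7. Classifier network loss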


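1. Data preparation

2. RPN training

2.1 Network model

The RPN model simply adds convolutional layers on top of ResNet: one shared 3x3 convolution followed by two parallel 1x1 convolutions, one for objectness and one for box regression.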

from keras.layers import Convolution2D  # Keras 2 alias of Conv2D

def rpn(base_layers, num_anchors):
    # Shared 3x3 convolution over the backbone feature map.
    x = Convolution2D(512, (3, 3), padding='same', activation='relu', kernel_initializer='normal', name='rpn_conv1')(base_layers)

    # One objectness score per anchor (sigmoid) and 4 box offsets per anchor (linear).
    x_class = Convolution2D(num_anchors, (1, 1), activation='sigmoid', kernel_initializer='uniform', name='rpn_out_class')(x)
    x_regr = Convolution2D(num_anchors * 4, (1, 1), activation='linear', kernel_initializer='zero', name='rpn_out_regress')(x)

    return [x_class, x_regr, base_layers]


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, None, None, 3 0                                            
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, None, None, 3 0           input_1[0][0]                    
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, None, None, 6 9472        zero_padding2d_1[0][0]           
__________________________________________________________________________________________________
bn_conv1 (FixedBatchNormalizati (None, None, None, 6 256         conv1[0][0]                      
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, None, 6 0           bn_conv1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, None, None, 6 0           activation_1[0][0]               
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, None, None, 6 4160        max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
bn2a_branch2a (FixedBatchNormal (None, None, None, 6 256         res2a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_2 (Activation)       (None, None, None, 6 0           bn2a_branch2a[0][0]              
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, None, None, 6 36928       activation_2[0][0]               
__________________________________________________________________________________________________
bn2a_branch2b (FixedBatchNormal (None, None, None, 6 256         res2a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_3 (Activation)       (None, None, None, 6 0           bn2a_branch2b[0][0]              
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, None, None, 2 16640       activation_3[0][0]               
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, None, None, 2 16640       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
bn2a_branch2c (FixedBatchNormal (None, None, None, 2 1024        res2a_branch2c[0][0]             
__________________________________________________________________________________________________
bn2a_branch1 (FixedBatchNormali (None, None, None, 2 1024        res2a_branch1[0][0]              
__________________________________________________________________________________________________
add_1 (Add)                     (None, None, None, 2 0           bn2a_branch2c[0][0]              
                                                                 bn2a_branch1[0][0]               
__________________________________________________________________________________________________
activation_4 (Activation)       (None, None, None, 2 0           add_1[0][0]                      
__________________________________________________________________________________________________
res2b_branch2a (Conv2D)         (None, None, None, 6 16448       activation_4[0][0]               
__________________________________________________________________________________________________
bn2b_branch2a (FixedBatchNormal (None, None, None, 6 256         res2b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_5 (Activation)       (None, None, None, 6 0           bn2b_branch2a[0][0]              
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, None, None, 6 36928       activation_5[0][0]               
__________________________________________________________________________________________________
bn2b_branch2b (FixedBatchNormal (None, None, None, 6 256         res2b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_6 (Activation)       (None, None, None, 6 0           bn2b_branch2b[0][0]              
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, None, None, 2 16640       activation_6[0][0]               
__________________________________________________________________________________________________
bn2b_branch2c (FixedBatchNormal (None, None, None, 2 1024        res2b_branch2c[0][0]             
__________________________________________________________________________________________________
add_2 (Add)                     (None, None, None, 2 0           bn2b_branch2c[0][0]              
                                                                 activation_4[0][0]               
__________________________________________________________________________________________________
activation_7 (Activation)       (None, None, None, 2 0           add_2[0][0]                      
__________________________________________________________________________________________________
res2c_branch2a (Conv2D)         (None, None, None, 6 16448       activation_7[0][0]               
__________________________________________________________________________________________________
bn2c_branch2a (FixedBatchNormal (None, None, None, 6 256         res2c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_8 (Activation)       (None, None, None, 6 0           bn2c_branch2a[0][0]              
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, None, None, 6 36928       activation_8[0][0]               
__________________________________________________________________________________________________
bn2c_branch2b (FixedBatchNormal (None, None, None, 6 256         res2c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_9 (Activation)       (None, None, None, 6 0           bn2c_branch2b[0][0]              
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, None, None, 2 16640       activation_9[0][0]               
__________________________________________________________________________________________________
bn2c_branch2c (FixedBatchNormal (None, None, None, 2 1024        res2c_branch2c[0][0]             
__________________________________________________________________________________________________
add_3 (Add)                     (None, None, None, 2 0           bn2c_branch2c[0][0]              
                                                                 activation_7[0][0]               
__________________________________________________________________________________________________
activation_10 (Activation)      (None, None, None, 2 0           add_3[0][0]                      
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, None, None, 1 32896       activation_10[0][0]              
__________________________________________________________________________________________________
bn3a_branch2a (FixedBatchNormal (None, None, None, 1 512         res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_11 (Activation)      (None, None, None, 1 0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, None, None, 1 147584      activation_11[0][0]              
__________________________________________________________________________________________________
bn3a_branch2b (FixedBatchNormal (None, None, None, 1 512         res3a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_12 (Activation)      (None, None, None, 1 0           bn3a_branch2b[0][0]              
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, None, None, 5 66048       activation_12[0][0]              
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, None, None, 5 131584      activation_10[0][0]              
__________________________________________________________________________________________________
bn3a_branch2c (FixedBatchNormal (None, None, None, 5 2048        res3a_branch2c[0][0]             
__________________________________________________________________________________________________
bn3a_branch1 (FixedBatchNormali (None, None, None, 5 2048        res3a_branch1[0][0]              
__________________________________________________________________________________________________
add_4 (Add)                     (None, None, None, 5 0           bn3a_branch2c[0][0]              
                                                                 bn3a_branch1[0][0]               
__________________________________________________________________________________________________
activation_13 (Activation)      (None, None, None, 5 0           add_4[0][0]                      
__________________________________________________________________________________________________
res3b_branch2a (Conv2D)         (None, None, None, 1 65664       activation_13[0][0]              
__________________________________________________________________________________________________
bn3b_branch2a (FixedBatchNormal (None, None, None, 1 512         res3b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_14 (Activation)      (None, None, None, 1 0           bn3b_branch2a[0][0]              
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, None, None, 1 147584      activation_14[0][0]              
__________________________________________________________________________________________________
bn3b_branch2b (FixedBatchNormal (None, None, None, 1 512         res3b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_15 (Activation)      (None, None, None, 1 0           bn3b_branch2b[0][0]              
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, None, None, 5 66048       activation_15[0][0]              
__________________________________________________________________________________________________
bn3b_branch2c (FixedBatchNormal (None, None, None, 5 2048        res3b_branch2c[0][0]             
__________________________________________________________________________________________________
add_5 (Add)                     (None, None, None, 5 0           bn3b_branch2c[0][0]              
                                                                 activation_13[0][0]              
__________________________________________________________________________________________________
activation_16 (Activation)      (None, None, None, 5 0           add_5[0][0]                      
__________________________________________________________________________________________________
res3c_branch2a (Conv2D)         (None, None, None, 1 65664       activation_16[0][0]              
__________________________________________________________________________________________________
bn3c_branch2a (FixedBatchNormal (None, None, None, 1 512         res3c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_17 (Activation)      (None, None, None, 1 0           bn3c_branch2a[0][0]              
__________________________________________________________________________________________________
res3c_branch2b (Conv2D)         (None, None, None, 1 147584      activation_17[0][0]              
__________________________________________________________________________________________________
bn3c_branch2b (FixedBatchNormal (None, None, None, 1 512         res3c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_18 (Activation)      (None, None, None, 1 0           bn3c_branch2b[0][0]              
__________________________________________________________________________________________________
res3c_branch2c (Conv2D)         (None, None, None, 5 66048       activation_18[0][0]              
__________________________________________________________________________________________________
bn3c_branch2c (FixedBatchNormal (None, None, None, 5 2048        res3c_branch2c[0][0]             
__________________________________________________________________________________________________
add_6 (Add)                     (None, None, None, 5 0           bn3c_branch2c[0][0]              
                                                                 activation_16[0][0]              
__________________________________________________________________________________________________
activation_19 (Activation)      (None, None, None, 5 0           add_6[0][0]                      
__________________________________________________________________________________________________
res3d_branch2a (Conv2D)         (None, None, None, 1 65664       activation_19[0][0]              
__________________________________________________________________________________________________
bn3d_branch2a (FixedBatchNormal (None, None, None, 1 512         res3d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_20 (Activation)      (None, None, None, 1 0           bn3d_branch2a[0][0]              
__________________________________________________________________________________________________
res3d_branch2b (Conv2D)         (None, None, None, 1 147584      activation_20[0][0]              
__________________________________________________________________________________________________
bn3d_branch2b (FixedBatchNormal (None, None, None, 1 512         res3d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_21 (Activation)      (None, None, None, 1 0           bn3d_branch2b[0][0]              
__________________________________________________________________________________________________
res3d_branch2c (Conv2D)         (None, None, None, 5 66048       activation_21[0][0]              
__________________________________________________________________________________________________
bn3d_branch2c (FixedBatchNormal (None, None, None, 5 2048        res3d_branch2c[0][0]             
__________________________________________________________________________________________________
add_7 (Add)                     (None, None, None, 5 0           bn3d_branch2c[0][0]              
                                                                 activation_19[0][0]              
__________________________________________________________________________________________________
activation_22 (Activation)      (None, None, None, 5 0           add_7[0][0]                      
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, None, None, 2 131328      activation_22[0][0]              
__________________________________________________________________________________________________
bn4a_branch2a (FixedBatchNormal (None, None, None, 2 1024        res4a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_23 (Activation)      (None, None, None, 2 0           bn4a_branch2a[0][0]              
__________________________________________________________________________________________________
res4a_branch2b (Conv2D)         (None, None, None, 2 590080      activation_23[0][0]              
__________________________________________________________________________________________________
bn4a_branch2b (FixedBatchNormal (None, None, None, 2 1024        res4a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_24 (Activation)      (None, None, None, 2 0           bn4a_branch2b[0][0]              
__________________________________________________________________________________________________
res4a_branch2c (Conv2D)         (None, None, None, 1 263168      activation_24[0][0]              
__________________________________________________________________________________________________
res4a_branch1 (Conv2D)          (None, None, None, 1 525312      activation_22[0][0]              
__________________________________________________________________________________________________
bn4a_branch2c (FixedBatchNormal (None, None, None, 1 4096        res4a_branch2c[0][0]             
__________________________________________________________________________________________________
bn4a_branch1 (FixedBatchNormali (None, None, None, 1 4096        res4a_branch1[0][0]              
__________________________________________________________________________________________________
add_8 (Add)                     (None, None, None, 1 0           bn4a_branch2c[0][0]              
                                                                 bn4a_branch1[0][0]               
__________________________________________________________________________________________________
activation_25 (Activation)      (None, None, None, 1 0           add_8[0][0]                      
__________________________________________________________________________________________________
res4b_branch2a (Conv2D)         (None, None, None, 2 262400      activation_25[0][0]              
__________________________________________________________________________________________________
bn4b_branch2a (FixedBatchNormal (None, None, None, 2 1024        res4b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_26 (Activation)      (None, None, None, 2 0           bn4b_branch2a[0][0]              
__________________________________________________________________________________________________
res4b_branch2b (Conv2D)         (None, None, None, 2 590080      activation_26[0][0]              
__________________________________________________________________________________________________
bn4b_branch2b (FixedBatchNormal (None, None, None, 2 1024        res4b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_27 (Activation)      (None, None, None, 2 0           bn4b_branch2b[0][0]              
__________________________________________________________________________________________________
res4b_branch2c (Conv2D)         (None, None, None, 1 263168      activation_27[0][0]              
__________________________________________________________________________________________________
bn4b_branch2c (FixedBatchNormal (None, None, None, 1 4096        res4b_branch2c[0][0]             
__________________________________________________________________________________________________
add_9 (Add)                     (None, None, None, 1 0           bn4b_branch2c[0][0]              
                                                                 activation_25[0][0]              
__________________________________________________________________________________________________
activation_28 (Activation)      (None, None, None, 1 0           add_9[0][0]                      
__________________________________________________________________________________________________
res4c_branch2a (Conv2D)         (None, None, None, 2 262400      activation_28[0][0]              
__________________________________________________________________________________________________
bn4c_branch2a (FixedBatchNormal (None, None, None, 2 1024        res4c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_29 (Activation)      (None, None, None, 2 0           bn4c_branch2a[0][0]              
__________________________________________________________________________________________________
res4c_branch2b (Conv2D)         (None, None, None, 2 590080      activation_29[0][0]              
__________________________________________________________________________________________________
bn4c_branch2b (FixedBatchNormal (None, None, None, 2 1024        res4c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_30 (Activation)      (None, None, None, 2 0           bn4c_branch2b[0][0]              
__________________________________________________________________________________________________
res4c_branch2c (Conv2D)         (None, None, None, 1 263168      activation_30[0][0]              
__________________________________________________________________________________________________
bn4c_branch2c (FixedBatchNormal (None, None, None, 1 4096        res4c_branch2c[0][0]             
__________________________________________________________________________________________________
add_10 (Add)                    (None, None, None, 1 0           bn4c_branch2c[0][0]              
                                                                 activation_28[0][0]              
__________________________________________________________________________________________________
activation_31 (Activation)      (None, None, None, 1 0           add_10[0][0]                     
__________________________________________________________________________________________________
res4d_branch2a (Conv2D)         (None, None, None, 2 262400      activation_31[0][0]              
__________________________________________________________________________________________________
bn4d_branch2a (FixedBatchNormal (None, None, None, 2 1024        res4d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_32 (Activation)      (None, None, None, 2 0           bn4d_branch2a[0][0]              
__________________________________________________________________________________________________
res4d_branch2b (Conv2D)         (None, None, None, 2 590080      activation_32[0][0]              
__________________________________________________________________________________________________
bn4d_branch2b (FixedBatchNormal (None, None, None, 2 1024        res4d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_33 (Activation)      (None, None, None, 2 0           bn4d_branch2b[0][0]              
__________________________________________________________________________________________________
res4d_branch2c (Conv2D)         (None, None, None, 1 263168      activation_33[0][0]              
__________________________________________________________________________________________________
bn4d_branch2c (FixedBatchNormal (None, None, None, 1 4096        res4d_branch2c[0][0]             
__________________________________________________________________________________________________
add_11 (Add)                    (None, None, None, 1 0           bn4d_branch2c[0][0]              
                                                                 activation_31[0][0]              
__________________________________________________________________________________________________
activation_34 (Activation)      (None, None, None, 1 0           add_11[0][0]                     
__________________________________________________________________________________________________
res4e_branch2a (Conv2D)         (None, None, None, 2 262400      activation_34[0][0]              
__________________________________________________________________________________________________
bn4e_branch2a (FixedBatchNormal (None, None, None, 2 1024        res4e_branch2a[0][0]             
__________________________________________________________________________________________________
activation_35 (Activation)      (None, None, None, 2 0           bn4e_branch2a[0][0]              
__________________________________________________________________________________________________
res4e_branch2b (Conv2D)         (None, None, None, 2 590080      activation_35[0][0]              
__________________________________________________________________________________________________
bn4e_branch2b (FixedBatchNormal (None, None, None, 2 1024        res4e_branch2b[0][0]             
__________________________________________________________________________________________________
activation_36 (Activation)      (None, None, None, 2 0           bn4e_branch2b[0][0]              
__________________________________________________________________________________________________
res4e_branch2c (Conv2D)         (None, None, None, 1 263168      activation_36[0][0]              
__________________________________________________________________________________________________
bn4e_branch2c (FixedBatchNormal (None, None, None, 1 4096        res4e_branch2c[0][0]             
__________________________________________________________________________________________________
add_12 (Add)                    (None, None, None, 1 0           bn4e_branch2c[0][0]              
                                                                 activation_34[0][0]              
__________________________________________________________________________________________________
activation_37 (Activation)      (None, None, None, 1 0           add_12[0][0]                     
__________________________________________________________________________________________________
res4f_branch2a (Conv2D)         (None, None, None, 2 262400      activation_37[0][0]              
__________________________________________________________________________________________________
bn4f_branch2a (FixedBatchNormal (None, None, None, 2 1024        res4f_branch2a[0][0]             
__________________________________________________________________________________________________
activation_38 (Activation)      (None, None, None, 2 0           bn4f_branch2a[0][0]              
__________________________________________________________________________________________________
res4f_branch2b (Conv2D)         (None, None, None, 2 590080      activation_38[0][0]              
__________________________________________________________________________________________________
bn4f_branch2b (FixedBatchNormal (None, None, None, 2 1024        res4f_branch2b[0][0]             
__________________________________________________________________________________________________
activation_39 (Activation)      (None, None, None, 2 0           bn4f_branch2b[0][0]              
__________________________________________________________________________________________________
res4f_branch2c (Conv2D)         (None, None, None, 1 263168      activation_39[0][0]              
__________________________________________________________________________________________________
bn4f_branch2c (FixedBatchNormal (None, None, None, 1 4096        res4f_branch2c[0][0]             
__________________________________________________________________________________________________
add_13 (Add)                    (None, None, None, 1 0           bn4f_branch2c[0][0]              
                                                                 activation_37[0][0]              
__________________________________________________________________________________________________
activation_40 (Activation)      (None, None, None, 1 0           add_13[0][0]                     
__________________________________________________________________________________________________
rpn_conv1 (Conv2D)              (None, None, None, 5 4719104     activation_40[0][0]              
__________________________________________________________________________________________________
rpn_out_class (Conv2D)          (None, None, None, 9 4617        rpn_conv1[0][0]                  
__________________________________________________________________________________________________
rpn_out_regress (Conv2D)        (None, None, None, 3 18468       rpn_conv1[0][0]                  
==================================================================================================
Total params: 13,331,373
Trainable params: 13,270,189
Non-trainable params: 61,184
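
The parameter counts of the three RPN layers at the end of the summary can be checked by hand. The sketch below is only an arithmetic check, not part of the original code: `conv2d_params` is a hypothetical helper, and the 1024 input channels of `rpn_conv1` are an assumption (the summary truncates the shape column, but 1024 is the usual ResNet-50 stage-4 width and it reproduces the reported count).

```python
def conv2d_params(kernel_h, kernel_w, in_channels, out_channels):
    """Number of weights in a Conv2D layer plus one bias per output channel."""
    return kernel_h * kernel_w * in_channels * out_channels + out_channels

num_anchors = 9

# rpn_conv1: 3x3 conv, assumed 1024 input channels, 512 output channels
rpn_conv1 = conv2d_params(3, 3, 1024, 512)
# rpn_out_class: 1x1 conv, 512 -> 9 (one objectness score per anchor)
rpn_out_class = conv2d_params(1, 1, 512, num_anchors)
# rpn_out_regress: 1x1 conv, 512 -> 36 (four offsets per anchor)
rpn_out_regress = conv2d_params(1, 1, 512, num_anchors * 4)

print(rpn_conv1, rpn_out_class, rpn_out_regress)  # 4719104 4617 18468
```

These match the `Param #` column for `rpn_conv1`, `rpn_out_class` and `rpn_out_regress` above.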

2.2 Training the network

Training this network then looks like this:

			X, Y, img_data = next(data_gen_train)

			loss_rpn = model_rpn.train_on_batch(X, Y)

			P_rpn = model_rpn.predict_on_batch(X)

Here the X returned by data_gen_train is the 600*750*3 image from data preparation, Y is the label data [(1,38,47,18), (1,38,47,72)], and img_data describes the original image (512,640,3).
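
The shapes quoted above can be reproduced with a little arithmetic. This is a rough sketch under stated assumptions, not the code the generator actually runs: it assumes the shorter side is resized to 600, that the backbone has a total stride of 16 with sizes rounded up, and that each Y tensor stores a validity mask followed by the targets. `resized_shape` and `feature_map_shape` are hypothetical helper names.

```python
import math

def resized_shape(height, width, min_side=600):
    """Scale the image so that its shorter side equals min_side."""
    scale = min_side / min(height, width)
    return round(height * scale), round(width * scale)

def feature_map_shape(height, width, stride=16):
    """Approximate feature-map size for an assumed total stride of 16."""
    return math.ceil(height / stride), math.ceil(width / stride)

num_anchors = 9
resized = resized_shape(512, 640)       # original image is 512 x 640
fmap = feature_map_shape(*resized)

# Each Y tensor holds a validity mask followed by the actual targets.
y_cls_shape = (1, *fmap, 2 * num_anchors)        # mask (9) + labels (9)
y_regr_shape = (1, *fmap, 2 * 4 * num_anchors)   # mask (36) + offsets (36)

print(resized, fmap, y_cls_shape, y_regr_shape)
```

Under these assumptions the result is (600, 750), (38, 47), (1, 38, 47, 18) and (1, 38, 47, 72), which agrees with the shapes reported in the text.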

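P_rpn holds the RPN predictions. The first element is rpn_out_class with shape (1,38,47,9): for each feature-map location, the probability that each anchor contains an object. The second is rpn_out_regress with shape (1,38,47,9*4): for each feature-map location, the 4 regression offsets of each anchor.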

You may be wondering why the training target Y and the prediction P_rpn have different dimensions.

2.3 RPN loss

Looking at the loss functions answers that question.

def rpn_loss_cls(num_anchors):
    def rpn_loss_cls_fixed_num(y_true, y_pred):
        # First num_anchors channels of y_true: validity mask; remaining channels: labels.
        valid = y_true[:, :, :, :num_anchors]
        labels = y_true[:, :, :, num_anchors:]
        bce = K.binary_crossentropy(y_pred[:, :, :, :], labels)
        return lambda_rpn_class * K.sum(valid * bce) / K.sum(epsilon + valid)

    return rpn_loss_cls_fixed_num

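The loss is: (the first 9 channels of Y, which flag whether each anchor is usable) * (the binary cross-entropy for object vs. no object, between the prediction and the last 9 channels of Y) / (the sum of the first 9 channels of Y, i.e. the number of usable anchors).

Put simply, many of the per-anchor losses between prediction and ground truth must not be counted, so they are multiplied by the first 9 channels to mask out unusable anchors and then divided by the number of usable anchors. The result is the average loss per usable anchor. This is also why Y has twice as many channels as P_rpn: half of them are the mask, half are the targets.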
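The masking and averaging described above can be checked outside Keras. Below is a minimal NumPy sketch of the same computation, assuming lambda_rpn_class is 1 and using an illustrative epsilon; the function name masked_rpn_cls_loss is hypothetical.

```python
import numpy as np

def masked_rpn_cls_loss(y_true, y_pred, num_anchors, eps=1e-4):
    """NumPy mirror of rpn_loss_cls (lambda_rpn_class assumed to be 1).

    y_true[..., :num_anchors]  -> 1 where the anchor is usable, else 0
    y_true[..., num_anchors:]  -> 1 where the anchor contains an object
    """
    valid = y_true[..., :num_anchors]
    target = y_true[..., num_anchors:]
    p = np.clip(y_pred, 1e-7, 1 - 1e-7)  # avoid log(0)
    bce = -(target * np.log(p) + (1 - target) * np.log(1 - p))
    # eps is added per element before summing, as in the Keras version
    return np.sum(valid * bce) / np.sum(eps + valid)

# One feature row with two points and a single anchor per point.
y_true = np.zeros((1, 1, 2, 2))
y_true[0, 0, 0, 0] = 1.0  # first anchor is usable
y_true[0, 0, 0, 1] = 1.0  # and it contains an object
y_pred = np.array([[[[0.5], [0.9]]]])  # second anchor is masked out
loss = masked_rpn_cls_loss(y_true, y_pred, num_anchors=1)
print(loss)
```

The second anchor's prediction of 0.9 has no effect on the result because its mask value is 0.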

def class_loss_regr(num_classes):
    def class_loss_regr_fixed_num(y_true, y_pred):
        x = y_true[:, :, 4 * num_classes:] - y_pred
        x_abs = K.abs(x)
        x_bool = K.cast(K.less_equal(x_abs, 1.0), 'float32')
        return lambda_cls_regr * K.sum(
            y_true[:, :, :4 * num_classes] * (x_bool * (0.5 * x * x) + (1 - x_bool) * (x_abs - 0.5))) / K.sum(
            epsilon + y_true[:, :, :4 * num_classes])

    return class_loss_regr_fixed_num

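Likewise, this one is the loss between the 4 predicted offsets and the ground truth, masked and averaged in the same way. Note that the per-element loss is smooth L1 (0.5*x*x when |x| <= 1, otherwise |x| - 0.5), not mean squared error. The code above is class_loss_regr, the classifier head's version; the rpn regression loss has the same masked form, with num_anchors in place of num_classes.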
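To see what the x_bool expression computes, here is a small NumPy sketch of the same masked smooth L1 loss. It assumes the lambda weight is 1 and uses an illustrative epsilon; the names smooth_l1 and masked_regr_loss are made up for this example.

```python
import numpy as np

def smooth_l1(x):
    """0.5*x^2 where |x| <= 1, otherwise |x| - 0.5 (same as the x_bool trick)."""
    ax = np.abs(x)
    return np.where(ax <= 1.0, 0.5 * x * x, ax - 0.5)

def masked_regr_loss(y_true, y_pred, n, eps=1e-4):
    """NumPy mirror of the regression loss for n classes (or n anchors).

    The first 4*n channels of y_true are the 0/1 mask,
    the last 4*n channels are the regression targets.
    """
    mask = y_true[..., :4 * n]
    target = y_true[..., 4 * n:]
    return np.sum(mask * smooth_l1(target - y_pred)) / np.sum(eps + mask)

# One ROI, one class: mask = 1,1,1,1 and targets = 0.5, 0, 0, 2.0
y_true = np.array([[[1, 1, 1, 1, 0.5, 0.0, 0.0, 2.0]]])
y_pred = np.zeros((1, 1, 4))
loss = masked_regr_loss(y_true, y_pred, n=1)
print(smooth_l1(np.array([0.5, 2.0])), loss)
```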

3. rpn_to_roi

The first line of the code divides the predicted bbox offsets by 4. Why?

Look at line 281, in get_anchor_gt:

y_rpn_regr[:, y_rpn_regr.shape[1] // 2:, :, :] *= C.std_scaling

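This line multiplies the bbox offsets in the training targets by 4. My guess is that scaling Y up keeps the values from being so small that they vanish during training. So the prediction

P_rpn = model_rpn.predict_on_batch(X)

yields

P_rpn[0] and P_rpn[1], which are respectively each anchor's objectness score and 4 times each anchor's 4 coordinate offsets. So before doing bbox regression we have to divide by 4 to recover the true offsets.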
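The scale-up and scale-down form a simple round trip. A tiny sketch, assuming C.std_scaling is 4.0 as described above (the helper names and the sample offsets are hypothetical):

```python
import numpy as np

STD_SCALING = 4.0  # assumed value of C.std_scaling, per the text

def scale_targets(regr):
    """What the data generator does to the regression targets."""
    return regr * STD_SCALING

def unscale_predictions(pred):
    """What rpn_to_roi has to undo before applying the offsets."""
    return pred / STD_SCALING

true_offsets = np.array([0.05, -0.02, 0.10, 0.00])
scaled = scale_targets(true_offsets)      # what the network learns to predict
recovered = unscale_predictions(scaled)   # what bbox regression should use
print(scaled, recovered)
```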

3.1 Bounding Box Regression

A[:, :, :, curr_layer] = apply_regr_np(A[:, :, :, curr_layer], regr)

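This applies the regression offsets to the anchors to obtain the actual bbox positions and sizes. Once the actual bboxes are known, invalid ones are filtered out: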

	all_boxes = np.delete(all_boxes, idxs, 0)
	all_probs = np.delete(all_probs, idxs, 0)

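all_boxes.shape=(16074,4), where 16074=38*47*9 is the total number of bboxes, and all_probs.shape=(16074), the probability that each bbox contains an object. Up to this point no classification has been done; this is still localization, producing the proposal boxes.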
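As a rough illustration of these two steps, the sketch below applies (tx, ty, tw, th) offsets to anchors given as (x, y, w, h) and then drops degenerate boxes with np.delete. It uses the usual Faster R-CNN parameterization (center shift scaled by width and height, exponential for size), which is an assumption about what apply_regr_np does; the function names are made up.

```python
import numpy as np

def apply_regr(anchors_xywh, deltas):
    """Shift and resize anchors (x, y, w, h) by offsets (tx, ty, tw, th)."""
    x, y, w, h = anchors_xywh.T
    tx, ty, tw, th = deltas.T
    cx = x + w / 2.0 + tx * w   # move the center by a fraction of the width
    cy = y + h / 2.0 + ty * h
    w1 = np.exp(tw) * w         # resize multiplicatively
    h1 = np.exp(th) * h
    return np.stack([cx - w1 / 2.0, cy - h1 / 2.0, w1, h1], axis=1)

def drop_degenerate(boxes_xyxy, probs):
    """Remove boxes (x1, y1, x2, y2) with non-positive width or height."""
    x1, y1, x2, y2 = boxes_xyxy.T
    idxs = np.where((x1 - x2 >= 0) | (y1 - y2 >= 0))[0]
    return np.delete(boxes_xyxy, idxs, 0), np.delete(probs, idxs, 0)

anchors = np.array([[0.0, 0.0, 10.0, 10.0], [0.0, 0.0, 10.0, 10.0]])
deltas = np.array([[0.1, 0.0, 0.0, 0.0], [0.0, 0.0, np.log(2.0), 0.0]])
moved = apply_regr(anchors, deltas)
boxes, probs = drop_degenerate(
    np.array([[0.0, 0.0, 5.0, 5.0], [3.0, 3.0, 3.0, 8.0]]),
    np.array([0.9, 0.8]),
)
print(moved, boxes, probs)
```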

3.2 NMS

That is far too many bboxes, so we apply NMS (non-maximum suppression):

result = non_max_suppression_fast(all_boxes, all_probs, overlap_thresh=overlap_thresh, max_boxes=max_boxes)[0]

overlap_thresh=0.9 and max_boxes=300, so at most 300 bboxes are kept.

return boxes, probs

What comes back is boxes with shape (300,4) and probs with shape (300).
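For reference, here is a minimal greedy NMS in NumPy. It is a sketch of the general technique under the same two parameters (overlap_thresh, max_boxes), not the repository's non_max_suppression_fast, whose internals may differ.

```python
import numpy as np

def nms(boxes, probs, overlap_thresh=0.9, max_boxes=300):
    """Greedy NMS over boxes given as (x1, y1, x2, y2).

    Repeatedly keep the highest-scoring box and drop every remaining
    box whose IoU with it exceeds overlap_thresh.
    """
    if len(boxes) == 0:
        return boxes, probs
    x1, y1, x2, y2 = boxes.T.astype(float)
    area = (x2 - x1) * (y2 - y1)
    order = np.argsort(probs)[::-1]  # indices, best score first
    keep = []
    while order.size > 0 and len(keep) < max_boxes:
        i, rest = order[0], order[1:]
        keep.append(i)
        iw = np.maximum(0.0, np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]))
        ih = np.maximum(0.0, np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]))
        inter = iw * ih
        iou = inter / (area[i] + area[rest] - inter + 1e-6)
        order = rest[iou <= overlap_thresh]
    return boxes[keep], probs[keep]

boxes = np.array([[0, 0, 10, 10], [0.2, 0, 10, 10], [20, 20, 30, 30]], dtype=float)
probs = np.array([0.9, 0.8, 0.7])
kept_boxes, kept_probs = nms(boxes, probs)
print(kept_boxes, kept_probs)
```

With a threshold as high as 0.9, only near-duplicates are suppressed; the second box above overlaps the first with IoU 0.98 and is removed.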

4. calc_iou

The inputs are the proposal boxes and the image data:

def calc_iou(R, img_data, C, class_mapping):
    # ... (body omitted)
    return np.expand_dims(X, axis=0), np.expand_dims(Y1, axis=0), np.expand_dims(Y2, axis=0), IoUs

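The 300 proposal boxes are compared against the ground-truth bboxes by IoU. Proposals whose IoU is below 0.1 are discarded; in my case only 5 survive, and these become X. A proposal with IoU between 0.1 and 0.5 is background, and one with IoU above 0.5 is an object. Each of the 5 proposals takes the class of the ground-truth bbox it overlaps most, encoded one-hot. My data has only 1 class plus background, 2 in total, so Y1.shape=(5,2).

For a background proposal, y_class_regr_label is 4 zeros and y_class_regr_coords is also 4 zeros. For an object, y_class_regr_label is 4 ones and y_class_regr_coords holds the 4 coordinate offsets (the bbox regression targets). So Y2.shape=(5,8), which is y_class_regr_label followed by y_class_regr_coords.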
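The thresholding and one-hot labelling can be sketched as follows. This simplified version only builds X and Y1 (it leaves out the regression targets in Y2), and it assumes the background class is the last column, which is consistent with how Y1[0, :, -1] is used in the training loop. The function names and the sample boxes are made up for illustration.

```python
import numpy as np

def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def label_rois(rois, gt_boxes, gt_classes, class_names, lo=0.1, hi=0.5):
    """Drop ROIs with best IoU < lo; label [lo, hi) as 'bg', >= hi as the GT class.

    class_names lists the foreground classes followed by 'bg' (last column).
    """
    kept, y1 = [], []
    for roi in rois:
        overlaps = [iou(roi, g) for g in gt_boxes]
        best = int(np.argmax(overlaps))
        if overlaps[best] < lo:
            continue  # too far from every ground-truth box
        name = gt_classes[best] if overlaps[best] >= hi else 'bg'
        one_hot = [0] * len(class_names)
        one_hot[class_names.index(name)] = 1
        kept.append(roi)
        y1.append(one_hot)
    return np.array(kept), np.array(y1)

rois = [[0, 0, 10, 10], [0, 0, 10, 3], [50, 50, 60, 60]]
X_rois, Y1 = label_rois(rois, [[0, 0, 10, 10]], ['obj'], ['obj', 'bg'])
print(X_rois, Y1)
```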

5. Feeding exactly 4 ROIs into the training network

			if X2 is None:
				rpn_accuracy_rpn_monitor.append(0)
				rpn_accuracy_for_epoch.append(0)
				continue

			neg_samples = np.where(Y1[0, :, -1] == 1)
			pos_samples = np.where(Y1[0, :, -1] == 0)

			if len(neg_samples) > 0:
				neg_samples = neg_samples[0]
			else:
				neg_samples = []

			if len(pos_samples) > 0:
				pos_samples = pos_samples[0]
			else:
				pos_samples = []
			
			rpn_accuracy_rpn_monitor.append(len(pos_samples))
			rpn_accuracy_for_epoch.append((len(pos_samples)))

			if C.num_rois > 1:
				if len(pos_samples) < C.num_rois//2:
					selected_pos_samples = pos_samples.tolist()
				else:
					selected_pos_samples = np.random.choice(pos_samples, C.num_rois//2, replace=False).tolist()
				try:
					selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=False).tolist()
				except:
					selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=True).tolist()

				sel_samples = selected_pos_samples + selected_neg_samples
			else:
				# in the extreme case where num_rois = 1, we pick a random pos or neg sample
				selected_pos_samples = pos_samples.tolist()
				selected_neg_samples = neg_samples.tolist()
				if np.random.randint(0, 2):
					sel_samples = random.choice(neg_samples)
				else:
					sel_samples = random.choice(pos_samples)

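Here sel_samples is the resulting list of selected proposal-box indices. With at least 2 positive samples it holds 2 positives and 2 negatives; with only one positive, 1 positive and 3 negatives; with none, 0 positives and 4 negatives.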
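The selection rule can be condensed into a few lines. The sketch below covers the num_rois > 1 branch only and uses NumPy's Generator API instead of np.random.choice; select_rois is a hypothetical name. Like the original, it fails if there are no negative samples at all.

```python
import numpy as np

def select_rois(y1, num_rois=4, rng=None):
    """Pick up to num_rois//2 positives and fill the rest with negatives.

    y1 is the (N, num_classes) one-hot label array with 'bg' in the last column.
    """
    rng = np.random.default_rng() if rng is None else rng
    neg = np.where(y1[:, -1] == 1)[0]
    pos = np.where(y1[:, -1] == 0)[0]
    n_pos = min(len(pos), num_rois // 2)
    sel_pos = rng.choice(pos, n_pos, replace=False).tolist() if n_pos else []
    n_neg = num_rois - n_pos
    # sample with replacement only when there are too few negatives
    sel_neg = rng.choice(neg, n_neg, replace=len(neg) < n_neg).tolist()
    return sel_pos + sel_neg

labels = np.array([[1, 0]] * 3 + [[0, 1]] * 5)  # 3 positives, 5 negatives
sel = select_rois(labels, num_rois=4, rng=np.random.default_rng(0))
print(sel)
```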

6. Training the classifier network model_classifier

loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

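X is the resized image, X2[:, sel_samples, :] is the 4 selected proposal boxes, and Y1[:, sel_samples, :] is the classes of those 4 boxes.

Y2[:, sel_samples, :] holds, for each of the 4 boxes, 4 background-or-object flags followed by 4 values that are either zeros or regression offsets.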

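7. Classifier network loss

Because my data has only one class, the class output has shape (1,4,2): 4 is the 4 ROIs and 2 is the (0,1) one-hot pair. The regression output has shape (1,4,4): 4 ROIs, each with 4 offsets.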

def class_loss_regr(num_classes):
	def class_loss_regr_fixed_num(y_true, y_pred):
		x = y_true[:, :, 4*num_classes:] - y_pred
		x_abs = K.abs(x)
		x_bool = K.cast(K.less_equal(x_abs, 1.0), 'float32')
		return lambda_cls_regr * K.sum(y_true[:, :, :4*num_classes] * (x_bool * (0.5 * x * x) + (1 - x_bool) * (x_abs - 0.5))) / K.sum(epsilon + y_true[:, :, :4*num_classes])
	return class_loss_regr_fixed_num


def class_loss_cls(y_true, y_pred):
	return lambda_cls_class * K.mean(categorical_crossentropy(y_true[0, :, :], y_pred[0, :, :]))

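Classification uses a multi-class (categorical cross-entropy) loss, computed directly.

The regression loss compares the last 4 values of y_true with the prediction using smooth L1, keeps only the entries whose flag in the first 4 values is 1 (object), and averages over them.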
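For the classification half, a NumPy sketch of what class_loss_cls computes looks like this. It assumes lambda_cls_class is 1 and a batch of one image; class_cls_loss is a hypothetical name. The per-ROI cross-entropies are averaged into a single scalar, so there is no separate loss per ROI in the result.

```python
import numpy as np

def class_cls_loss(y_true, y_pred, eps=1e-7):
    """Mean categorical cross-entropy over the ROIs of the first image.

    y_true and y_pred have shape (1, num_rois, num_classes).
    """
    p = np.clip(y_pred[0], eps, 1.0)  # avoid log(0)
    per_roi = -np.sum(y_true[0] * np.log(p), axis=-1)
    return per_roi.mean()

y_true = np.array([[[1.0, 0.0], [0.0, 1.0]]])      # object, background
y_pred = np.array([[[0.5, 0.5], [0.25, 0.75]]])
loss = class_cls_loss(y_true, y_pred)
print(loss)
```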

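rpn model prediction results

classes model prediction results

rpn model loss

It looks like 0 = 1 + 2, where 1 is the rpn classification loss and 2 is the rpn regression loss.

classes model loss

It looks like 0 = 1 + 2 + 3, where 1 is the class classification loss, 2 is the class regression loss, and 3 is the classification accuracy of the bboxes predicted by the rpn (not sure).

The losses are written with only two outputs, so why does the rpn model return 3 values? And the class model takes 4 ROIs as input, so it should output (4,2) losses; why are there only 4 values, and why isn't there one loss per ROI? A likely explanation is that Keras's train_on_batch returns the total loss first, then one scalar loss per output, then any metrics passed to compile, and each of those scalars is already reduced over the batch and the ROIs. If so, index 3 of the classes model would be a metric rather than a term in the sum.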
