1. Data preparation
2. Training the RPN network
2.1 The network model
The RPN model is just a small head added on top of the ResNet feature map: one shared 3x3 convolutional layer followed by two sibling 1x1 convolutional layers (classification and regression).
def rpn(base_layers, num_anchors):
    # shared 3x3 conv over the backbone feature map
    x = Convolution2D(512, (3, 3), padding='same', activation='relu', kernel_initializer='normal', name='rpn_conv1')(base_layers)
    # objectness probability for each of the num_anchors anchors per position
    x_class = Convolution2D(num_anchors, (1, 1), activation='sigmoid', kernel_initializer='uniform', name='rpn_out_class')(x)
    # 4 box-regression targets for each anchor per position
    x_regr = Convolution2D(num_anchors * 4, (1, 1), activation='linear', kernel_initializer='zero', name='rpn_out_regress')(x)
    return [x_class, x_regr, base_layers]
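With keras-frcnn's default configuration of 3 anchor scales and 3 aspect ratios, num_anchors = 9, which is where the 9-channel and 36-channel head outputs in the model summary come from. A quick sanity check of those channel counts (plain Python; the scale/ratio values are the repo's defaults):

```python
# keras-frcnn's default anchor setup: 3 scales x 3 aspect ratios per feature-map point
anchor_box_scales = [128, 256, 512]
anchor_box_ratios = [[1, 1], [1, 2], [2, 1]]

num_anchors = len(anchor_box_scales) * len(anchor_box_ratios)

cls_channels = num_anchors        # one objectness score per anchor -> rpn_out_class
regr_channels = num_anchors * 4   # (tx, ty, tw, th) per anchor -> rpn_out_regress

print(num_anchors, cls_channels, regr_channels)  # 9 9 36
```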
__________________________________________________________________________________
Layer (type)                              Output Shape             Param #   Connected to
==================================================================================
input_1 (InputLayer)                      (None, None, None, 3)    0
zero_padding2d_1 (ZeroPadding2D)          (None, None, None, 3)    0         input_1[0][0]
conv1 (Conv2D)                            (None, None, None, 64)   9472      zero_padding2d_1[0][0]
bn_conv1 (FixedBatchNormalization)        (None, None, None, 64)   256       conv1[0][0]
activation_1 (Activation)                 (None, None, None, 64)   0         bn_conv1[0][0]
max_pooling2d_1 (MaxPooling2D)            (None, None, None, 64)   0         activation_1[0][0]
res2a_branch2a (Conv2D)                   (None, None, None, 64)   4160      max_pooling2d_1[0][0]
bn2a_branch2a (FixedBatchNormalization)   (None, None, None, 64)   256       res2a_branch2a[0][0]
activation_2 (Activation)                 (None, None, None, 64)   0         bn2a_branch2a[0][0]
res2a_branch2b (Conv2D)                   (None, None, None, 64)   36928     activation_2[0][0]
bn2a_branch2b (FixedBatchNormalization)   (None, None, None, 64)   256       res2a_branch2b[0][0]
activation_3 (Activation)                 (None, None, None, 64)   0         bn2a_branch2b[0][0]
res2a_branch2c (Conv2D)                   (None, None, None, 256)  16640     activation_3[0][0]
res2a_branch1 (Conv2D)                    (None, None, None, 256)  16640     max_pooling2d_1[0][0]
bn2a_branch2c (FixedBatchNormalization)   (None, None, None, 256)  1024      res2a_branch2c[0][0]
bn2a_branch1 (FixedBatchNormalization)    (None, None, None, 256)  1024      res2a_branch1[0][0]
add_1 (Add)                               (None, None, None, 256)  0         bn2a_branch2c[0][0]
                                                                             bn2a_branch1[0][0]
activation_4 (Activation)                 (None, None, None, 256)  0         add_1[0][0]
res2b_branch2a (Conv2D)                   (None, None, None, 64)   16448     activation_4[0][0]
bn2b_branch2a (FixedBatchNormalization)   (None, None, None, 64)   256       res2b_branch2a[0][0]
activation_5 (Activation)                 (None, None, None, 64)   0         bn2b_branch2a[0][0]
res2b_branch2b (Conv2D)                   (None, None, None, 64)   36928     activation_5[0][0]
bn2b_branch2b (FixedBatchNormalization)   (None, None, None, 64)   256       res2b_branch2b[0][0]
activation_6 (Activation)                 (None, None, None, 64)   0         bn2b_branch2b[0][0]
res2b_branch2c (Conv2D)                   (None, None, None, 256)  16640     activation_6[0][0]
bn2b_branch2c (FixedBatchNormalization)   (None, None, None, 256)  1024      res2b_branch2c[0][0]
add_2 (Add)                               (None, None, None, 256)  0         bn2b_branch2c[0][0]
                                                                             activation_4[0][0]
activation_7 (Activation)                 (None, None, None, 256)  0         add_2[0][0]
res2c_branch2a (Conv2D)                   (None, None, None, 64)   16448     activation_7[0][0]
bn2c_branch2a (FixedBatchNormalization)   (None, None, None, 64)   256       res2c_branch2a[0][0]
activation_8 (Activation)                 (None, None, None, 64)   0         bn2c_branch2a[0][0]
res2c_branch2b (Conv2D)                   (None, None, None, 64)   36928     activation_8[0][0]
bn2c_branch2b (FixedBatchNormalization)   (None, None, None, 64)   256       res2c_branch2b[0][0]
activation_9 (Activation)                 (None, None, None, 64)   0         bn2c_branch2b[0][0]
res2c_branch2c (Conv2D)                   (None, None, None, 256)  16640     activation_9[0][0]
bn2c_branch2c (FixedBatchNormalization)   (None, None, None, 256)  1024      res2c_branch2c[0][0]
add_3 (Add)                               (None, None, None, 256)  0         bn2c_branch2c[0][0]
                                                                             activation_7[0][0]
activation_10 (Activation)                (None, None, None, 256)  0         add_3[0][0]
res3a_branch2a (Conv2D)                   (None, None, None, 128)  32896     activation_10[0][0]
bn3a_branch2a (FixedBatchNormalization)   (None, None, None, 128)  512       res3a_branch2a[0][0]
activation_11 (Activation)                (None, None, None, 128)  0         bn3a_branch2a[0][0]
res3a_branch2b (Conv2D)                   (None, None, None, 128)  147584    activation_11[0][0]
bn3a_branch2b (FixedBatchNormalization)   (None, None, None, 128)  512       res3a_branch2b[0][0]
activation_12 (Activation)                (None, None, None, 128)  0         bn3a_branch2b[0][0]
res3a_branch2c (Conv2D)                   (None, None, None, 512)  66048     activation_12[0][0]
res3a_branch1 (Conv2D)                    (None, None, None, 512)  131584    activation_10[0][0]
bn3a_branch2c (FixedBatchNormalization)   (None, None, None, 512)  2048      res3a_branch2c[0][0]
bn3a_branch1 (FixedBatchNormalization)    (None, None, None, 512)  2048      res3a_branch1[0][0]
add_4 (Add)                               (None, None, None, 512)  0         bn3a_branch2c[0][0]
                                                                             bn3a_branch1[0][0]
activation_13 (Activation)                (None, None, None, 512)  0         add_4[0][0]
res3b_branch2a (Conv2D)                   (None, None, None, 128)  65664     activation_13[0][0]
bn3b_branch2a (FixedBatchNormalization)   (None, None, None, 128)  512       res3b_branch2a[0][0]
activation_14 (Activation)                (None, None, None, 128)  0         bn3b_branch2a[0][0]
res3b_branch2b (Conv2D)                   (None, None, None, 128)  147584    activation_14[0][0]
bn3b_branch2b (FixedBatchNormalization)   (None, None, None, 128)  512       res3b_branch2b[0][0]
activation_15 (Activation)                (None, None, None, 128)  0         bn3b_branch2b[0][0]
res3b_branch2c (Conv2D)                   (None, None, None, 512)  66048     activation_15[0][0]
bn3b_branch2c (FixedBatchNormalization)   (None, None, None, 512)  2048      res3b_branch2c[0][0]
add_5 (Add)                               (None, None, None, 512)  0         bn3b_branch2c[0][0]
                                                                             activation_13[0][0]
activation_16 (Activation)                (None, None, None, 512)  0         add_5[0][0]
res3c_branch2a (Conv2D)                   (None, None, None, 128)  65664     activation_16[0][0]
bn3c_branch2a (FixedBatchNormalization)   (None, None, None, 128)  512       res3c_branch2a[0][0]
activation_17 (Activation)                (None, None, None, 128)  0         bn3c_branch2a[0][0]
res3c_branch2b (Conv2D)                   (None, None, None, 128)  147584    activation_17[0][0]
bn3c_branch2b (FixedBatchNormalization)   (None, None, None, 128)  512       res3c_branch2b[0][0]
activation_18 (Activation)                (None, None, None, 128)  0         bn3c_branch2b[0][0]
res3c_branch2c (Conv2D)                   (None, None, None, 512)  66048     activation_18[0][0]
bn3c_branch2c (FixedBatchNormalization)   (None, None, None, 512)  2048      res3c_branch2c[0][0]
add_6 (Add)                               (None, None, None, 512)  0         bn3c_branch2c[0][0]
                                                                             activation_16[0][0]
activation_19 (Activation)                (None, None, None, 512)  0         add_6[0][0]
res3d_branch2a (Conv2D)                   (None, None, None, 128)  65664     activation_19[0][0]
bn3d_branch2a (FixedBatchNormalization)   (None, None, None, 128)  512       res3d_branch2a[0][0]
activation_20 (Activation)                (None, None, None, 128)  0         bn3d_branch2a[0][0]
res3d_branch2b (Conv2D)                   (None, None, None, 128)  147584    activation_20[0][0]
bn3d_branch2b (FixedBatchNormalization)   (None, None, None, 128)  512       res3d_branch2b[0][0]
activation_21 (Activation)                (None, None, None, 128)  0         bn3d_branch2b[0][0]
res3d_branch2c (Conv2D)                   (None, None, None, 512)  66048     activation_21[0][0]
bn3d_branch2c (FixedBatchNormalization)   (None, None, None, 512)  2048      res3d_branch2c[0][0]
add_7 (Add)                               (None, None, None, 512)  0         bn3d_branch2c[0][0]
                                                                             activation_19[0][0]
activation_22 (Activation)                (None, None, None, 512)  0         add_7[0][0]
res4a_branch2a (Conv2D)                   (None, None, None, 256)  131328    activation_22[0][0]
bn4a_branch2a (FixedBatchNormalization)   (None, None, None, 256)  1024      res4a_branch2a[0][0]
activation_23 (Activation)                (None, None, None, 256)  0         bn4a_branch2a[0][0]
res4a_branch2b (Conv2D)                   (None, None, None, 256)  590080    activation_23[0][0]
bn4a_branch2b (FixedBatchNormalization)   (None, None, None, 256)  1024      res4a_branch2b[0][0]
activation_24 (Activation)                (None, None, None, 256)  0         bn4a_branch2b[0][0]
res4a_branch2c (Conv2D)                   (None, None, None, 1024) 263168    activation_24[0][0]
res4a_branch1 (Conv2D)                    (None, None, None, 1024) 525312    activation_22[0][0]
bn4a_branch2c (FixedBatchNormalization)   (None, None, None, 1024) 4096      res4a_branch2c[0][0]
bn4a_branch1 (FixedBatchNormalization)    (None, None, None, 1024) 4096      res4a_branch1[0][0]
add_8 (Add)                               (None, None, None, 1024) 0         bn4a_branch2c[0][0]
                                                                             bn4a_branch1[0][0]
activation_25 (Activation)                (None, None, None, 1024) 0         add_8[0][0]
res4b_branch2a (Conv2D)                   (None, None, None, 256)  262400    activation_25[0][0]
bn4b_branch2a (FixedBatchNormalization)   (None, None, None, 256)  1024      res4b_branch2a[0][0]
activation_26 (Activation)                (None, None, None, 256)  0         bn4b_branch2a[0][0]
res4b_branch2b (Conv2D)                   (None, None, None, 256)  590080    activation_26[0][0]
bn4b_branch2b (FixedBatchNormalization)   (None, None, None, 256)  1024      res4b_branch2b[0][0]
activation_27 (Activation)                (None, None, None, 256)  0         bn4b_branch2b[0][0]
res4b_branch2c (Conv2D)                   (None, None, None, 1024) 263168    activation_27[0][0]
bn4b_branch2c (FixedBatchNormalization)   (None, None, None, 1024) 4096      res4b_branch2c[0][0]
add_9 (Add)                               (None, None, None, 1024) 0         bn4b_branch2c[0][0]
                                                                             activation_25[0][0]
activation_28 (Activation)                (None, None, None, 1024) 0         add_9[0][0]
res4c_branch2a (Conv2D)                   (None, None, None, 256)  262400    activation_28[0][0]
bn4c_branch2a (FixedBatchNormalization)   (None, None, None, 256)  1024      res4c_branch2a[0][0]
activation_29 (Activation)                (None, None, None, 256)  0         bn4c_branch2a[0][0]
res4c_branch2b (Conv2D)                   (None, None, None, 256)  590080    activation_29[0][0]
bn4c_branch2b (FixedBatchNormalization)   (None, None, None, 256)  1024      res4c_branch2b[0][0]
activation_30 (Activation)                (None, None, None, 256)  0         bn4c_branch2b[0][0]
res4c_branch2c (Conv2D)                   (None, None, None, 1024) 263168    activation_30[0][0]
bn4c_branch2c (FixedBatchNormalization)   (None, None, None, 1024) 4096      res4c_branch2c[0][0]
add_10 (Add)                              (None, None, None, 1024) 0         bn4c_branch2c[0][0]
                                                                             activation_28[0][0]
activation_31 (Activation)                (None, None, None, 1024) 0         add_10[0][0]
res4d_branch2a (Conv2D)                   (None, None, None, 256)  262400    activation_31[0][0]
bn4d_branch2a (FixedBatchNormalization)   (None, None, None, 256)  1024      res4d_branch2a[0][0]
activation_32 (Activation)                (None, None, None, 256)  0         bn4d_branch2a[0][0]
res4d_branch2b (Conv2D)                   (None, None, None, 256)  590080    activation_32[0][0]
bn4d_branch2b (FixedBatchNormalization)   (None, None, None, 256)  1024      res4d_branch2b[0][0]
activation_33 (Activation)                (None, None, None, 256)  0         bn4d_branch2b[0][0]
res4d_branch2c (Conv2D)                   (None, None, None, 1024) 263168    activation_33[0][0]
bn4d_branch2c (FixedBatchNormalization)   (None, None, None, 1024) 4096      res4d_branch2c[0][0]
add_11 (Add)                              (None, None, None, 1024) 0         bn4d_branch2c[0][0]
                                                                             activation_31[0][0]
activation_34 (Activation)                (None, None, None, 1024) 0         add_11[0][0]
res4e_branch2a (Conv2D)                   (None, None, None, 256)  262400    activation_34[0][0]
bn4e_branch2a (FixedBatchNormalization)   (None, None, None, 256)  1024      res4e_branch2a[0][0]
activation_35 (Activation)                (None, None, None, 256)  0         bn4e_branch2a[0][0]
res4e_branch2b (Conv2D)                   (None, None, None, 256)  590080    activation_35[0][0]
bn4e_branch2b (FixedBatchNormalization)   (None, None, None, 256)  1024      res4e_branch2b[0][0]
activation_36 (Activation)                (None, None, None, 256)  0         bn4e_branch2b[0][0]
res4e_branch2c (Conv2D)                   (None, None, None, 1024) 263168    activation_36[0][0]
bn4e_branch2c (FixedBatchNormalization)   (None, None, None, 1024) 4096      res4e_branch2c[0][0]
add_12 (Add)                              (None, None, None, 1024) 0         bn4e_branch2c[0][0]
                                                                             activation_34[0][0]
activation_37 (Activation)                (None, None, None, 1024) 0         add_12[0][0]
res4f_branch2a (Conv2D)                   (None, None, None, 256)  262400    activation_37[0][0]
bn4f_branch2a (FixedBatchNormalization)   (None, None, None, 256)  1024      res4f_branch2a[0][0]
activation_38 (Activation)                (None, None, None, 256)  0         bn4f_branch2a[0][0]
res4f_branch2b (Conv2D)                   (None, None, None, 256)  590080    activation_38[0][0]
bn4f_branch2b (FixedBatchNormalization)   (None, None, None, 256)  1024      res4f_branch2b[0][0]
activation_39 (Activation)                (None, None, None, 256)  0         bn4f_branch2b[0][0]
res4f_branch2c (Conv2D)                   (None, None, None, 1024) 263168    activation_39[0][0]
bn4f_branch2c (FixedBatchNormalization)   (None, None, None, 1024) 4096      res4f_branch2c[0][0]
add_13 (Add)                              (None, None, None, 1024) 0         bn4f_branch2c[0][0]
                                                                             activation_37[0][0]
activation_40 (Activation)                (None, None, None, 1024) 0         add_13[0][0]
rpn_conv1 (Conv2D)                        (None, None, None, 512)  4719104   activation_40[0][0]
rpn_out_class (Conv2D)                    (None, None, None, 9)    4617      rpn_conv1[0][0]
rpn_out_regress (Conv2D)                  (None, None, None, 36)   18468     rpn_conv1[0][0]
==================================================================================
Total params: 13,331,373
Trainable params: 13,270,189
Non-trainable params: 61,184
2.2 Training the network
Training this network looks like:

X, Y, img_data = next(data_gen_train)
loss_rpn = model_rpn.train_on_batch(X, Y)
P_rpn = model_rpn.predict_on_batch(X)

Here X, yielded by data_gen_train, is the 600*750*3 image produced during data preparation; Y is the label data [(1, 38, 47, 18), (1, 38, 47, 72)]; and img_data is the original (512, 640, 3) image.
P_rpn is the RPN's prediction: its first element, rpn_out_class, has shape (1, 38, 47, 9) and gives the probability that each anchor at each feature-map point contains an object; its second element, rpn_out_regress, has shape (1, 38, 47, 9*4) and gives the 4 box-regression deltas for each anchor at each point.
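As a side note, the 38x47 spatial size follows from ResNet's total stride of 16. A small sketch (feature_map_size is my own helper, assuming the backbone ceil-divides each dimension, which matches the shapes above):

```python
import math

def feature_map_size(height, width, stride=16):
    # ResNet-50 up through stage 4 downsamples by a total stride of 16
    return math.ceil(height / stride), math.ceil(width / stride)

print(feature_map_size(600, 750))  # (38, 47)
```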
So you may well wonder: why do the training labels Y and the predictions P_rpn have different shapes?
2.3 The RPN loss
Let's look at the loss to find out.
def rpn_loss_cls(num_anchors):
    def rpn_loss_cls_fixed_num(y_true, y_pred):
        # y_true channels [0, num_anchors): validity mask (may this anchor be used?)
        # y_true channels [num_anchors, 2*num_anchors): object/background labels
        return lambda_rpn_class * K.sum(
            y_true[:, :, :, :num_anchors]
            * K.binary_crossentropy(y_pred[:, :, :, :], y_true[:, :, :, num_anchors:])
        ) / K.sum(epsilon + y_true[:, :, :, :num_anchors])
    return rpn_loss_cls_fixed_num
The first 9 channels of Y flag whether each anchor is usable; they gate the binary cross-entropy computed between the predictions and the last 9 channels of Y (the object/background labels), and the sum is divided by the number of usable anchors.
In plain terms: many anchors must not contribute to the loss, so the first 9 channels act as a mask selecting the valid ones, and dividing by their count gives the average loss per valid anchor. This also resolves the shape mismatch above: Y packs a validity mask together with the labels, which is why it has twice as many channels as the prediction.
def rpn_loss_regr(num_anchors):
    def rpn_loss_regr_fixed_num(y_true, y_pred):
        # y_true channels [0, 4*num_anchors): validity mask,
        # channels [4*num_anchors, 8*num_anchors): the regression targets
        x = y_true[:, :, :, 4 * num_anchors:] - y_pred
        x_abs = K.abs(x)
        x_bool = K.cast(K.less_equal(x_abs, 1.0), 'float32')
        # smooth L1: 0.5*x^2 where |x| <= 1, otherwise |x| - 0.5
        return lambda_rpn_regr * K.sum(
            y_true[:, :, :, :4 * num_anchors]
            * (x_bool * (0.5 * x * x) + (1 - x_bool) * (x_abs - 0.5))
        ) / K.sum(epsilon + y_true[:, :, :, :4 * num_anchors])
    return rpn_loss_regr_fixed_num

Correspondingly, this is the loss between each anchor's 4 regression deltas and the ground truth, masked the same way; note that it uses a smooth L1 (Huber) loss rather than a plain squared error.
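The piecewise expression built from x_bool is exactly the smooth L1 (Huber) function; sketched in numpy (smooth_l1 is my own helper):

```python
import numpy as np

def smooth_l1(x):
    # 0.5 * x^2 where |x| <= 1, |x| - 0.5 elsewhere -- the same piecewise
    # form as the x_bool expression in the Keras loss above
    x_abs = np.abs(x)
    return np.where(x_abs <= 1.0, 0.5 * x * x, x_abs - 0.5)

vals = smooth_l1(np.array([0.5, 1.0, 2.0]))  # 0.125, 0.5, 1.5
```

Unlike plain squared error, the penalty grows only linearly for large residuals, so a few badly-off anchors cannot dominate the loss.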
3. rpn_to_roi
The first line of this function shrinks the predicted bbox deltas by a factor of 4 (C.std_scaling). Why?
Look at line 281 in get_anchor_gt:

y_rpn_regr[:, y_rpn_regr.shape[1] // 2:, :, :] *= C.std_scaling

This line scales the ground-truth bbox deltas up by a factor of 4. My guess is that enlarging the Y values keeps the regression targets from being so small that they effectively vanish during training. Consequently, for the prediction

P_rpn = model_rpn.predict_on_batch(X)

P_rpn[0] and P_rpn[1] are, respectively, each anchor's objectness score and 4 times each anchor's 4 coordinate deltas, so before doing the bbox regression we must divide by 4 to recover the true deltas.
3.1 Bounding Box Regression
A[:, :, :, curr_layer] = apply_regr_np(A[:, :, :, curr_layer], regr)
This applies the regression deltas to the anchors to obtain the actual bbox coordinates; after filtering out invalid boxes we get
all_boxes = np.delete(all_boxes, idxs, 0)
all_probs = np.delete(all_probs, idxs, 0)
all_boxes.shape = (16074, 4), where 16074 = 38*47*9 — these are all the candidate bboxes; all_probs.shape = (16074,) is each bbox's probability of containing an object. Up to this point we have done no classification at all; we are still localizing, producing candidate boxes.
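apply_regr_np uses the standard Faster R-CNN box parameterization. A simplified single-box sketch (apply_regr here is my own restatement, operating on an (x, y, w, h) box; remember the predicted deltas are divided by C.std_scaling first):

```python
import numpy as np

def apply_regr(x, y, w, h, tx, ty, tw, th):
    # standard Faster R-CNN inverse parameterization:
    # (tx, ty) shift the box centre, (tw, th) rescale width/height exponentially
    cx, cy = x + w / 2.0, y + h / 2.0
    cx1, cy1 = tx * w + cx, ty * h + cy
    w1, h1 = np.exp(tw) * w, np.exp(th) * h
    return cx1 - w1 / 2.0, cy1 - h1 / 2.0, w1, h1

# zero deltas leave the anchor unchanged
box = apply_regr(10, 10, 20, 40, 0, 0, 0, 0)
```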
3.2 NMS
That is still far too many bboxes, so we apply non-maximum suppression (NMS):
result = non_max_suppression_fast(all_boxes, all_probs, overlap_thresh=overlap_thresh, max_boxes=max_boxes)[0]
overlap_thresh=0.9 and max_boxes=300, i.e. keep at most 300 bboxes.
return boxes, probs
It returns boxes.shape = (300, 4) and probs.shape = (300,).
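non_max_suppression_fast greedily keeps the highest-scoring box and discards every remaining box that overlaps it by more than overlap_thresh. A minimal sketch of the same idea (my own simplified version, not the repo's vectorized code; boxes are (x1, y1, x2, y2) rows):

```python
import numpy as np

def nms(boxes, probs, overlap_thresh=0.9, max_boxes=300):
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    area = (x2 - x1) * (y2 - y1)
    idxs = np.argsort(probs)  # ascending: the best box sits at the end
    pick = []
    while len(idxs) > 0 and len(pick) < max_boxes:
        i = idxs[-1]          # highest remaining score
        pick.append(i)
        # IoU of every other remaining box with box i
        xx1 = np.maximum(x1[i], x1[idxs[:-1]])
        yy1 = np.maximum(y1[i], y1[idxs[:-1]])
        xx2 = np.minimum(x2[i], x2[idxs[:-1]])
        yy2 = np.minimum(y2[i], y2[idxs[:-1]])
        inter = np.maximum(0, xx2 - xx1) * np.maximum(0, yy2 - yy1)
        overlap = inter / (area[i] + area[idxs[:-1]] - inter)
        # drop box i (already kept) plus everything overlapping it too much
        suppress = np.concatenate(([len(idxs) - 1],
                                   np.where(overlap > overlap_thresh)[0]))
        idxs = np.delete(idxs, suppress)
    return boxes[pick], probs[pick]
```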
4. calc_iou
Its inputs are the proposal boxes and the image data:
def calc_iou(R, img_data, C, class_mapping):
return np.expand_dims(X, axis=0), np.expand_dims(Y1, axis=0), np.expand_dims(Y2, axis=0), IoUs
Take the IoU between the 300 proposal boxes and the ground-truth bboxes. After discarding proposals whose IoU is below 0.1, only 5 survive in my case; these are stored in X. Proposals with IoU in [0.1, 0.5) are labeled background, and those with IoU >= 0.5 are labeled object; each proposal takes the class of the ground-truth bbox it overlaps most, one-hot encoded. Since my data has only 1 class plus background, Y1.shape = (5, 2). For a background proposal, y_class_regr_label is four 0s and y_class_regr_coords is also four 0s; for an object, y_class_regr_label is four 1s and y_class_regr_coords holds the 4 coordinate deltas (the bbox-regression targets). Hence Y2.shape = (5, 8), which is y_class_regr_label concatenated with y_class_regr_coords.
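The IoU computation and the 0.1/0.5 labeling rule described above can be sketched as follows (iou and label_proposal are my own helpers; boxes are (x1, y1, x2, y2)):

```python
def iou(a, b):
    # intersection-over-union of two (x1, y1, x2, y2) boxes
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / float(area_a + area_b - inter)

def label_proposal(best_iou):
    # < 0.1: discarded; [0.1, 0.5): background; >= 0.5: object
    if best_iou < 0.1:
        return 'discard'
    return 'bg' if best_iou < 0.5 else 'object'
```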
5. Feeding a fixed number of ROIs (4) into the training network
if X2 is None:
    rpn_accuracy_rpn_monitor.append(0)
    rpn_accuracy_for_epoch.append(0)
    continue

# the last column of Y1 is the background class
neg_samples = np.where(Y1[0, :, -1] == 1)
pos_samples = np.where(Y1[0, :, -1] == 0)

if len(neg_samples) > 0:
    neg_samples = neg_samples[0]
else:
    neg_samples = []

if len(pos_samples) > 0:
    pos_samples = pos_samples[0]
else:
    pos_samples = []

rpn_accuracy_rpn_monitor.append(len(pos_samples))
rpn_accuracy_for_epoch.append(len(pos_samples))

if C.num_rois > 1:
    if len(pos_samples) < C.num_rois // 2:
        selected_pos_samples = pos_samples.tolist()
    else:
        selected_pos_samples = np.random.choice(pos_samples, C.num_rois // 2, replace=False).tolist()
    try:
        # too few negatives -> sampling without replacement fails, so fall back
        selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=False).tolist()
    except:
        selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=True).tolist()
    sel_samples = selected_pos_samples + selected_neg_samples
else:
    # in the extreme case where num_rois = 1, we pick a random pos or neg sample
    selected_pos_samples = pos_samples.tolist()
    selected_neg_samples = neg_samples.tolist()
    if np.random.randint(0, 2):
        sel_samples = random.choice(neg_samples)
    else:
        sel_samples = random.choice(pos_samples)
Here sel_samples is the list of selected proposal indices: if there are at least 2 positive samples, we take 2 positives and 2 negatives; if there is only 1 positive, we take 1 positive and 3 negatives; with no positives, 4 negatives.
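The sampling rule above, restated as a self-contained sketch (select_rois is my own hypothetical helper, with num_rois = 4 as in the text):

```python
import numpy as np

def select_rois(pos_samples, neg_samples, num_rois=4):
    # take up to num_rois // 2 positives, then fill the batch with negatives,
    # sampling negatives with replacement only if there are too few of them
    if len(pos_samples) < num_rois // 2:
        sel_pos = list(pos_samples)
    else:
        sel_pos = np.random.choice(pos_samples, num_rois // 2, replace=False).tolist()
    need_neg = num_rois - len(sel_pos)
    sel_neg = np.random.choice(neg_samples, need_neg,
                               replace=len(neg_samples) < need_neg).tolist()
    return sel_pos + sel_neg

# 5 positives available -> 2 positives + 2 negatives
sel = select_rois(pos_samples=[0, 1, 2, 3, 4], neg_samples=[10, 11, 12])
```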
6. Training the classification network model_classifier
loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])
X is the rescaled image; X2[:, sel_samples, :] is the 4 selected proposal boxes; Y1[:, sel_samples, :] holds the 4 proposals' class labels; and
Y2[:, sel_samples, :] holds, for each of the 4 proposals, 4 background/object flags followed by 4 zeros (background) or the 4 regression deltas (object).
7. The classification network's loss
Because my data has only one class, the final class output is (1, 4, 2) — 4 ROIs, 2 classes (object, background) — and the regression output is (1, 4, 4) — 4 ROIs with 4 deltas each.
def class_loss_regr(num_classes):
    def class_loss_regr_fixed_num(y_true, y_pred):
        # y_true: [4*num_classes flags | 4*num_classes regression targets]
        x = y_true[:, :, 4 * num_classes:] - y_pred
        x_abs = K.abs(x)
        x_bool = K.cast(K.less_equal(x_abs, 1.0), 'float32')
        # smooth L1, masked by the flags and averaged over the flagged entries
        return lambda_cls_regr * K.sum(
            y_true[:, :, :4 * num_classes]
            * (x_bool * (0.5 * x * x) + (1 - x_bool) * (x_abs - 0.5))
        ) / K.sum(epsilon + y_true[:, :, :4 * num_classes])
    return class_loss_regr_fixed_num
def class_loss_cls(y_true, y_pred):
    return lambda_cls_class * K.mean(categorical_crossentropy(y_true[0, :, :], y_pred[0, :, :]))

The classification part is a plain multi-class cross-entropy, computed directly.
For the regression part, the last 4 values per class are compared with the prediction using smooth L1 (not a plain squared difference); the first 4 flags (set to 1 where the ROI is an object) mask the sum, which is then averaged over the flagged entries.
RPN model predictions (screenshot omitted)
Classifier model predictions (screenshot omitted)
RPN model loss (screenshot omitted)
It appears that element 0 = element 1 + element 2: element 1 is the RPN classification loss and element 2 is the RPN box-regression loss.
Classifier model loss (screenshot omitted)
It appears that element 0 = element 1 + element 2: element 1 is the classifier's classification loss, element 2 is its box-regression loss, and element 3 is the classification accuracy on the bboxes proposed by the RPN (a metric, not part of the total loss; I'm not certain).
The loss functions define only two outputs, so why does running the rpn model produce 3 numbers? And the classifier model takes 4 ROIs as input, so shouldn't its output be (4, 2) losses, one per ROI, rather than just 4 numbers? In fact this is simply how Keras reports results: train_on_batch returns the total loss first, then one scalar per model output (each already averaged over the whole batch), followed by any metrics such as accuracy — so there is never a per-ROI loss.