由于frcnn的网络结构主要是两个网络组成,损失函数分为四个部分。RPN分类损失:anchor是否为gt
RPN位置回归损失:anchor位置微调
ROI分类损失:ROI所属类别
ROI位置回归损失:继续对ROI位置微调
四个损失相加就是最后的损失,反向传播,更新参数。
RPN损失
gt_rpn_loc, gt_rpn_label = self.anchor_target_creator(
at.tonumpy(bbox),
anchor,
img_size)
参数分别是gt坐标,处理过的anchor,图片的H和W信息。
AnchorTargetCreator类
class AnchorTargetCreator(object):
def __init__(self,
n_sample=256,
pos_iou_thresh=0.7, neg_iou_thresh=0.3,
pos_ratio=0.5):
self.n_sample = n_sample
self.pos_iou_thresh = pos_iou_thresh#大于0.7为正样本
self.neg_iou_thresh = neg_iou_thresh#小于0.3为负样本
self.pos_ratio = pos_ratio
def __call__(self, bbox, anchor, img_size):
img_H, img_W = img_size
n_anchor = len(anchor)