Link: https://www.zhihu.com/question/42205480/answer/525212289
Source: Zhihu
Now let's walk through how the RPN is trained:
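1. Suppose the input image is 576*960. VGG16 produces a convolutional feature map downsampled by a factor of 16, giving a 36*60*512 feature map, where 512 is the number of channels.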
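2. Slide a 3*3*512 window over the feature map (stride=1, padding=1, so the output keeps the 36*60 size). For each window, find the pixel position in the original image that corresponds to the window's center, then generate 9 anchors on the original image. These 9 anchors share one center, which coincides with the center of the window. There are 36*60 windows, hence 36*60*9 = 19440 anchors on the original image. A 3*3*512 window is enough to describe the features of a 228-pixel region (its effective receptive field on the input image for VGG16).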
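3. While generating the anchors, decide from the IoU with the ground truth whether each anchor is a positive or a negative sample (that is, whether it is an object). Two kinds of anchors receive a positive label: (1) the anchor or anchors with the highest IoU with a ground-truth box; (2) any anchor whose IoU with some ground-truth box is higher than 0.7. Note that a single ground-truth box may assign positive labels to multiple anchors. Condition (2) is usually enough to determine the positives; condition (1) is kept because in some cases condition (2) finds no positive sample at all. Among the non-positive anchors, those whose IoU with every ground-truth box is lower than 0.3 are labeled negative. The remaining anchors, neither positive nor negative, are ignored and do not take part in training. So the 9 anchors generated by one sliding window may each be labeled positive, negative, or ignored. Then, for each positive anchor, the ground-truth box matched to it provides the regression label, which the regressor belonging to that anchor is trained to predict. Once the anchors are generated, the regression labels of the positive anchors are therefore fixed.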
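4. Randomly sample 256 anchors from one image, with a positive-to-negative ratio of up to 1:1. If the image has fewer than 128 positive anchors, the original Faster R-CNN paper pads the mini-batch with negative anchors instead of reducing the negatives to match. Note that these 256 anchors come from different 3*3 sliding windows.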
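5. Now the network itself. For each sampled anchor in the mini-batch, the corresponding 3*3*512 window goes through a convolution (3*3*512 kernel) to produce a 512-dimensional vector. This vector feeds the classification layer and 9 independent regression layers, so the output is always (4+2)*9 dimensional: 4 box offsets and 2 class scores for each of the 9 anchors. For a positive anchor, the regression output is compared with the label derived from its matched ground-truth box to compute the regression loss (smooth L1 in the original Faster R-CNN paper), and its classification loss is computed as well. Negative anchors contribute only the classification loss. The outputs for ignored anchors contribute a loss of 0.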
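The regression label used in steps 3 and 5 is not spelled out above. Here is a minimal sketch, assuming the box parameterization and smooth L1 loss of the original Faster R-CNN paper, with boxes given as (x1, y1, x2, y2). The names `encode_box` and `smooth_l1` are illustrative and do not come from any library.

```python
import math


def encode_box(anchor, gt):
    """Regression targets (tx, ty, tw, th) of a ground-truth box relative to an anchor.

    Both boxes are (x1, y1, x2, y2) in image pixels.
    """
    ax1, ay1, ax2, ay2 = anchor
    gx1, gy1, gx2, gy2 = gt
    aw, ah = ax2 - ax1, ay2 - ay1
    gw, gh = gx2 - gx1, gy2 - gy1
    acx, acy = ax1 + 0.5 * aw, ay1 + 0.5 * ah
    gcx, gcy = gx1 + 0.5 * gw, gy1 + 0.5 * gh
    return (
        (gcx - acx) / aw,      # center shift in x, in units of anchor width
        (gcy - acy) / ah,      # center shift in y, in units of anchor height
        math.log(gw / aw),     # log width ratio
        math.log(gh / ah),     # log height ratio
    )


def smooth_l1(pred, target):
    """Smooth L1 loss summed over the 4 box coordinates."""
    loss = 0.0
    for p, t in zip(pred, target):
        d = abs(p - t)
        loss += 0.5 * d * d if d < 1.0 else d - 0.5
    return loss


anchor = (100.0, 100.0, 228.0, 228.0)   # a 128 x 128 anchor
gt = (110.0, 90.0, 250.0, 230.0)        # a 140 x 140 ground-truth box
target = encode_box(anchor, gt)

# A regressor that predicts "no change" pays a loss; a perfect one pays none.
loss_zero_pred = smooth_l1((0.0, 0.0, 0.0, 0.0), target)
loss_perfect = smooth_l1(target, target)
```

Only positive anchors contribute this term. Negative anchors contribute the classification loss alone, and ignored anchors contribute nothing, as described in step 5.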
My own understanding:
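An anchor is better called an anchor box. It is essentially a candidate box generated by mapping a feature-map position back onto the original image, which is exactly what the word means: an anchor is a fixed point, and what it fixes is the anchor box.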
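To make the "fixed point" idea concrete, here is a minimal sketch that maps each feature-map cell of the 576*960 example from step 1 back to the image and places 9 boxes on it. The scales (128, 256, 512) and aspect ratios (0.5, 1, 2) are those of the original Faster R-CNN paper and are an assumption here, as is the convention of putting the center in the middle of each 16-pixel cell. `generate_anchors` is an illustrative name.

```python
import math


def generate_anchors(img_h, img_w, stride=16,
                     scales=(128, 256, 512), ratios=(0.5, 1.0, 2.0)):
    """Return a list of (x1, y1, x2, y2) anchors, 9 per feature-map cell."""
    feat_h, feat_w = img_h // stride, img_w // stride
    anchors = []
    for i in range(feat_h):            # rows of the feature map
        for j in range(feat_w):        # columns of the feature map
            # Center of this cell in image coordinates: the fixed "anchor" point.
            cx = (j + 0.5) * stride
            cy = (i + 0.5) * stride
            for s in scales:
                for r in ratios:
                    # Keep the area at s*s while setting height / width = r.
                    w = s / math.sqrt(r)
                    h = s * math.sqrt(r)
                    anchors.append((cx - w / 2, cy - h / 2,
                                    cx + w / 2, cy + h / 2))
    return anchors


anchors = generate_anchors(576, 960)
print(len(anchors))   # 36 * 60 * 9 = 19440
```

Boxes near the image border extend outside the image in this sketch. How to treat them (clip or drop) is a separate design choice not covered here.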
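There is a lot of room for tuning in how anchors are split into positive, negative, and ignored samples. The current rule is: IoU with the ground truth above 0.7 is positive, below 0.3 is negative, and anything in between is ignored. The ignored samples matter, because many boxes are small and have indistinct features, so the only option is to leave them out.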
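IoU itself is not defined above. Here is a minimal sketch for two boxes in (x1, y1, x2, y2) form, together with the three-way split by the 0.7 / 0.3 thresholds. The names `iou` and `label_anchor` are illustrative, and the "highest IoU with a ground truth" rule is left out to keep the example short.

```python
def iou(a, b):
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def label_anchor(anchor, gts, pos_thresh=0.7, neg_thresh=0.3):
    """Label one anchor by its best IoU over all ground-truth boxes."""
    best = max((iou(anchor, g) for g in gts), default=0.0)
    if best > pos_thresh:
        return "positive"
    if best < neg_thresh:
        return "negative"
    return "ignored"


gts = [(0.0, 0.0, 100.0, 100.0)]
labels = [
    label_anchor((0.0, 0.0, 100.0, 90.0), gts),   # IoU 0.9
    label_anchor((0.0, 0.0, 100.0, 50.0), gts),   # IoU 0.5
    label_anchor((0.0, 0.0, 100.0, 20.0), gts),   # IoU 0.2
]
print(labels)   # ['positive', 'ignored', 'negative']
```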
I looked at the code; the part that splits anchors into positive and negative samples by threshold is as follows:
import torch


def match(bboxes, gt, cfg, gt_ignores=None):
    """
    Match roi to gt

    Temporarily used tensors:
        overlaps (FloatTensor): [N, M], ious of dt(N) with gt(M)
        ignore_overlaps (FloatTensor): [N, K], ious of dt(N) with ignore regions(K)

    Returns:
        target (LongTensor): [N], matched gt index for each roi.
            1. if a roi is positive, its target is the matched gt index (>=0)
            2. if a roi is negative, its target is -1
            3. if a roi is neither positive nor negative, its target is -2
    """
    NEGATIVE_TARGET = -1
    IGNORE_TARGET = -2

    N = bboxes.shape[0]
    M = gt.shape[0]

    # bbox_iou_overlaps and bbox_iof_overlaps are helpers defined elsewhere
    # in the original codebase; they are not shown in this excerpt.
    # check M > 0 for no-gt support
    overlaps = bbox_iou_overlaps(bboxes, gt) if M > 0 else bboxes.new_zeros(N, 1)
    ignore_overlaps = None
    if gt_ignores is not None and gt_ignores.numel() > 0:
        ignore_overlaps = bbox_iof_overlaps(bboxes, gt_ignores)

    # every dt starts out as ignored
    target = bboxes.new_full((N,), IGNORE_TARGET, dtype=torch.int64)
    dt_to_gt_max, dt_to_gt_argmax = overlaps.max(dim=1)

    # rule 1: negative if maxiou < negative_iou_thresh
    neg_mask = dt_to_gt_max < cfg['negative_iou_thresh']
    target[neg_mask] = NEGATIVE_TARGET

    # rule 2: positive if maxiou > positive_iou_thresh
    pos_mask = dt_to_gt_max > cfg['positive_iou_thresh']
    target[pos_mask] = dt_to_gt_argmax[pos_mask]

    # rule 3: positive if a dt has highest iou with any gt
    if cfg.get('allow_low_quality_match') and M > 0:
        overlaps = overlaps.t()  # IMPORTANT, for faster calculation
        gt_to_dt_max, _ = overlaps.max(dim=1)
        # (gt, dt) index pairs whose iou is within 1e-3 of that gt's best iou
        dt_gt_pairs = torch.nonzero(overlaps >= gt_to_dt_max[:, None] - 1e-3)
        if dt_gt_pairs.numel() > 0:
            lqm_dt_inds = dt_gt_pairs[:, 1]
            target[lqm_dt_inds] = dt_to_gt_argmax[lqm_dt_inds]
            pos_mask[lqm_dt_inds] = True

    # rule 4: a dt with high iou with ignore regions should not be treated as negative
    if ignore_overlaps is not None and ignore_overlaps.numel() > 0:
        dt_to_ig_max, _ = ignore_overlaps.max(dim=1)
        ignored_dt_mask = dt_to_ig_max > cfg['ignore_iou_thresh']
        # remove positives from ignored
        ignored_dt_mask = ignored_dt_mask ^ (ignored_dt_mask & pos_mask)
        target[ignored_dt_mask] = IGNORE_TARGET

    return target
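Step 4 then samples a mini-batch from these labels. Here is a minimal sketch in plain Python, assuming the target convention returned by `match` (index >= 0 for positives, -1 for negatives, -2 for ignored). When there are fewer than 128 positives, the remaining slots are filled with negatives, as in the original Faster R-CNN paper. `sample_minibatch` is an illustrative name and is not part of the code above.

```python
import random


def sample_minibatch(targets, batch_size=256, pos_fraction=0.5, seed=None):
    """Pick anchor indices for one mini-batch.

    targets[i] >= 0 : positive, -1 : negative, -2 : ignored (never sampled).
    Returns (positive indices, negative indices).
    """
    rng = random.Random(seed)
    pos = [i for i, t in enumerate(targets) if t >= 0]
    neg = [i for i, t in enumerate(targets) if t == -1]
    # At most half of the batch is positive.
    num_pos = min(len(pos), int(batch_size * pos_fraction))
    # Negatives fill whatever is left of the batch.
    num_neg = min(len(neg), batch_size - num_pos)
    return rng.sample(pos, num_pos), rng.sample(neg, num_neg)


# Toy labels: 30 positives, 5000 negatives, 400 ignored anchors.
targets = [0] * 30 + [-1] * 5000 + [-2] * 400
pos_idx, neg_idx = sample_minibatch(targets, seed=0)
print(len(pos_idx), len(neg_idx))   # 30 226
```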