在从文件读入标注的数据时,会把物体数量向物体最多的那张图补齐,补齐的时候会添加进不少无效的框,最后计算的时候需要将这部分无效数据去除,添加的无效数据为(0,0,0,0),现在需要将这部分数据去掉
#把添加的无效数据去除
def gt_mask_from_gts(gts):
gt_stk = gts.view(-1, 4)
invalid_gt = torch.Tensor([0, 0, 0, 0])
if CAN_USE_GPU:
invalid_gt = invalid_gt.cuda()
gt_mask = torch.zeros(size=(gt_stk.shape[0], ))
gt_mask[gt_stk.eq(invalid_gt.view(1, 4)).sum(1) != 4] = 1
return gt_mask.view(gts.shape[0], gts.shape[1])
部分无效的iou也需要同样去掉
如图所示,红色框的中心是在第二个黑色框中,那么第一个黑色的框对应的anchor是不需要和红色的框计算iou的,那么只需要保留第二个anchor和红色框的iou,那么可以用对应的tensor来表示,如:0,0,0,0,0,1,1,1,1,1就表示取第二个框对应的anchor和红色框计算iou
def range_mask_from_gts(gts, w_n, anchor_num, cell_anchor_n