Retinaface中match函数的理解

对于match函数的理解:

def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
    """Assign a ground-truth face to every prior (anchor) and write its training targets.

    Nothing is returned; the results are written in place into ``loc_t[idx]``,
    ``conf_t[idx]`` and ``landm_t[idx]``.

    Args:
        threshold: IoU below which a prior is labelled background (class 0).
        truths: ground-truth boxes, one row per object.
        priors: prior boxes. They are passed through ``point_form`` before the
            IoU is computed and are handed to ``encode``/``encode_landm`` as-is.
        variances: forwarded unchanged to ``encode`` and ``encode_landm``.
        labels: one class label per ground-truth object. ``conf_t[idx]`` holds
            ``[num_priors]`` values, so ``labels`` is expected to be 1-D.
        landms: landmark coordinates per ground-truth object.
        loc_t, conf_t, landm_t: batch target tensors; only slot ``idx`` is written.
        idx: index of the current image within the batch.
    """
    if truths.size(0) == 0:
        # No ground truth in this image. overlaps.max(0) below would raise on an
        # empty dimension, so every prior becomes background with zero targets.
        loc_t[idx] = 0
        conf_t[idx] = 0
        landm_t[idx] = 0
        return

    # IoU of every ground truth (rows) against every prior (columns).
    overlaps = jaccard(
        truths,
        point_form(priors)
    )

    # For each ground truth, the index of its best-overlapping prior: [num_truths]
    _, best_prior_idx = overlaps.max(1)

    # For each prior, its best IoU and the index of that ground truth: [num_priors]
    best_truth_overlap, best_truth_idx = overlaps.max(0)

    # Make sure every ground truth keeps at least one prior. First give its best
    # prior an overlap of 2, which is above any IoU threshold...
    best_truth_overlap.index_fill_(0, best_prior_idx, 2)
    # ...then point that prior back at this ground truth. A sequential loop keeps
    # the rule "the later ground truth wins" when two share the same best prior.
    for j, prior_j in enumerate(best_prior_idx.tolist()):
        best_truth_idx[prior_j] = j

    # Per-prior matched box [num_priors, 4], label [num_priors] and landmarks.
    # Advanced indexing copies, so editing conf below does not touch labels.
    matches = truths[best_truth_idx]
    conf = labels[best_truth_idx]
    matches_landm = landms[best_truth_idx]

    # Priors whose best IoU is below the threshold are background (class 0).
    conf[best_truth_overlap < threshold] = 0

    # Encode relative to the priors; these are the values the network should predict.
    loc_t[idx] = encode(matches, priors, variances)            # [num_priors, 4]
    conf_t[idx] = conf                                         # [num_priors]
    landm_t[idx] = encode_landm(matches_landm, priors, variances)  # [num_priors, 10]

下面用随机数据模拟一下 match 的执行过程:

import torch

# Five fake ground-truth boxes with 4 coordinates each: [num_objects, 4].
# They are random numbers, so they need not form valid (x1 < x2, y1 < y2) boxes.
truths = torch.randn(5, 4)
print("truths:", truths)

# Every object is a face, i.e. class 1. NOTE: this is [num_objects, 1]. match()
# itself expects 1-D labels (its conf_t[idx] is [num_priors]), which is why the
# demo's conf ends up as [num_priors, 1].
labels = torch.ones(5, 1)
print("labels:", labels)

# Stand-in IoU matrix [num_objects, num_priors] with values in [0, 1); a real run
# would compute it with jaccard(truths, point_form(priors)).
overlaps = torch.rand(5, 10)
print("overlaps:", overlaps)

overlaps 是真实框与先验框的 IoU 矩阵,形状为 [真实框个数, 先验框个数]:每一行对应一个真实框,每一列对应一个先验框。为了演示,这里直接用 torch.rand 随机生成,而不是用 jaccard 计算。

truths 表示真实框的坐标,格式为左上角和右下角坐标 (x1, y1, x2, y2)。这里用 torch.randn 随机生成,只用来演示索引过程,不保证是合法的框。

truths: tensor([[ 0.6497, -1.5030,  0.5384, -0.2993],
        [-0.3744,  0.7603, -0.0568,  1.5666],
        [ 0.2875, -1.1328, -1.5941, -0.1617],
        [-0.9408,  1.1018,  1.3028, -2.0684],
        [-0.2388, -1.0179,  0.6412,  1.0418]])
labels: tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.]])
overlaps: tensor([[0.5321, 0.7111, 0.2009, 0.0500, 0.6987, 0.1410, 0.5055, 0.4217, 0.5109,
         0.8002],
        [0.1247, 0.9751, 0.6667, 0.6870, 0.2429, 0.4588, 0.3245, 0.7241, 0.5880,
         0.6663],
        [0.8734, 0.7823, 0.8094, 0.1350, 0.4765, 0.2627, 0.4242, 0.9889, 0.8328,
         0.6587],
        [0.8570, 0.4075, 0.9889, 0.2479, 0.8518, 0.1139, 0.6657, 0.3636, 0.2662,
         0.6904],
        [0.5021, 0.8086, 0.7512, 0.6637, 0.2433, 0.9300, 0.8903, 0.3788, 0.4600,
         0.7378]])

这里有两点需要注意:

1、首先真实框要找到和它匹配程度最大的先验框

2、然后每个先验框再找到与自己重合程度最大的真实框

如果某一列的最大值也低于 threshold,说明这个先验框与任何真实框都不怎么重合,会被标记为背景。实际中,大多数先验框与所有真实框的 IoU 都接近 0。

真实框要找到和它匹配程度最大的先验框

在 IoU 矩阵的每一行中取最大值,就能找到每个真实框对应的最佳先验框。

# dim=1 reduces across the columns of each row: for every ground truth, the best
# IoU over all priors and the index of that prior ([num_objects, 1] each).
best_prior_overlap, best_prior_idx = overlaps.max(dim=1, keepdim=True)
print(overlaps.max(dim=1, keepdim=True))

max(1) 沿第 1 维归约,得到每一行的最大值,即每个真实框的最佳先验框;max(0) 沿第 0 维归约,得到每一列的最大值,即每个先验框的最佳真实框。keepdim=True 会保留被归约的那一维(长度变为 1)。

torch.return_types.max(
values=tensor([[0.8002],
        [0.9751],
        [0.9889],
        [0.9889],
        [0.9300]]),
indices=tensor([[9],
        [1],
        [7],
        [2],
        [5]]))

还需要求每一列的最大 IoU,让每个先验框找到与它重合程度最大的真实框。

# dim=0 reduces down each column: for every prior, the best IoU over all ground
# truths and the index of that ground truth ([1, num_priors] each).
best_truth_overlap, best_truth_idx = overlaps.max(dim=0, keepdim=True)
print(overlaps.max(dim=0, keepdim=True))

torch.return_types.max(
values=tensor([[0.8734, 0.9751, 0.9889, 0.6870, 0.8518, 0.9300, 0.8903, 0.9889, 0.8328,
         0.8002]]),
indices=tensor([[2, 1, 3, 1, 3, 4, 4, 2, 2, 0]]))

用 squeeze_ 原地去掉 keepdim=True 留下的长度为 1 的维度。

# Drop, in place, the size-1 axes that keepdim=True left behind.
# Per ground truth: [num_objects, 1] -> [num_objects]
best_prior_overlap.squeeze_(1)
best_prior_idx.squeeze_(1)

# Per prior: [1, num_priors] -> [num_priors]
best_truth_overlap.squeeze_(0)
best_truth_idx.squeeze_(0)

把每个真实框的最佳先验框在 best_truth_overlap 中的值置为 2(大于任何阈值),保证这些先验框一定是正样本,不会因为 IoU 低于阈值而被当作背景。

# Overwrite the overlap of every ground truth's best prior with 2. Any value
# above the threshold would work; a real IoU never exceeds 1.
best_truth_overlap.index_fill_(0, best_prior_idx, 2)
print("best_truth_overlap:", best_truth_overlap)

best_truth_overlap: tensor([0.8734, 2.0000, 2.0000, 0.6870, 0.8518, 2.0000, 0.8903, 2.0000, 0.8328,2.0000])

上一步只修改了 best_truth_overlap,还需要同步修改 best_truth_idx:让每个真实框 j 的最佳先验框都指向 j。如果两个真实框共用同一个最佳先验框,以后处理的那个为准。(本例中这 5 个先验框原本就指向各自的真实框,所以下面的输出与修改前相同。)

# Keep best_truth_idx consistent with the forced overlaps: ground truth j claims
# its own best prior. If two ground truths share a best prior, the later one wins.
for j in range(best_prior_idx.size(0)):
    best_truth_idx[best_prior_idx[j]] = j

# Dedented to column 0: the original one-space indent after a 4-space loop body
# raises IndentationError.
print("best_truth_idx:", best_truth_idx)

best_truth_idx: tensor([2, 1, 3, 1, 3, 4, 4, 2, 2, 0])

根据 best_truth_idx,取出每个先验框对应的真实框坐标。

# For every prior, gather the coordinates of the ground-truth box it is
# matched to: [num_priors, 4].
matches = truths[best_truth_idx]
print("matches:", matches)

matches: tensor([[ 0.2875, -1.1328, -1.5941, -0.1617],
        [-0.3744,  0.7603, -0.0568,  1.5666],
        [-0.9408,  1.1018,  1.3028, -2.0684],
        [-0.3744,  0.7603, -0.0568,  1.5666],
        [-0.9408,  1.1018,  1.3028, -2.0684],
        [-0.2388, -1.0179,  0.6412,  1.0418],
        [-0.2388, -1.0179,  0.6412,  1.0418],
        [ 0.2875, -1.1328, -1.5941, -0.1617],
        [ 0.2875, -1.1328, -1.5941, -0.1617],
        [ 0.6497, -1.5030,  0.5384, -0.2993]])

再取出每个先验框(anchor)所匹配的真实框的类别标签。


print("best_truth_overlap:", best_truth_overlap)

# Look up the class of each prior's matched ground truth. labels is
# [num_objects, 1] in this demo, so conf comes out as [num_priors, 1].
conf = labels[best_truth_idx]
print("conf:", conf)

best_truth_overlap: tensor([0.8734, 2.0000, 2.0000, 0.6870, 0.8518, 2.0000, 0.8903, 2.0000, 0.8328,2.0000])

conf: tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]]) 

如果重合程度小于 threshold,就认为该先验框是背景(类别记为 0)。本例中 best_truth_overlap 的最小值是 0.6870,大于 0.45,所以 conf 没有变化。

# IoU cut-off for this demo: any prior whose (possibly forced) best overlap is
# below it is relabelled as background, class 0.
threshold = 0.45
conf[best_truth_overlap < threshold] = 0
print("conf:", conf)

conf: tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]]) 

Process finished with exit code 0
 

链接:

https://blog.csdn.net/weixin_41779359/article/details/111414567

https://blog.csdn.net/weixin_44791964/article/details/106872072#_9

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值