SSD目标检测损失函数部分代码解析

SSD目标检测损失函数部分代码解析

SSD的loss计算,接下来就来逐一分析一下计算方法。
打开MultiBoxloss类。

	# 回归信息,置信度,先验框
    loc_data, conf_data, priors = predictions
    # 计算出batch_size
    num = loc_data.size(0)
    # 取出所有的先验框
    priors = priors[:loc_data.size(1), :]
    # 先验框的数量
    num_priors = (priors.size(0))
    # 创建一个tensor进行处理
    loc_t = torch.Tensor(num, num_priors, 4)
    conf_t = torch.LongTensor(num, num_priors)

    if self.use_gpu:
        loc_t = loc_t.cuda()
        conf_t = conf_t.cuda()
        priors = priors.cuda()

    for idx in range(num):
        # 获得框
        truths = targets[idx][:, :-1]
        # 获得标签
        labels = targets[idx][:, -1]
        # 获得先验框
        defaults = priors
        # 找到标签对应的先验框
        match(self.threshold, truths, defaults, self.variance, labels,
              loc_t, conf_t, idx)

这几部就是常规操作了,把模型卷积之后的结果和实际的标签都获取之后,准备计算loss。然后重点是这个match函数,看一下这个函数的内容。

def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):
    # 计算所有的先验框和真实框的重合程度
    overlaps = jaccard(
        truths,
        point_form(priors)
    )
    # 所有真实框和先验框的最好重合程度
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
    best_prior_idx.squeeze_(1)
    best_prior_overlap.squeeze_(1)
    
    # 所有先验框和真实框的最好重合程度
    # [1,prior]
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
    best_truth_idx.squeeze_(0)
    best_truth_overlap.squeeze_(0)
    # 找到与真实框重合程度最好的先验框,用于保证每个真实框都要有对应的一个先验框
    best_truth_overlap.index_fill_(0, best_prior_idx, 2)
    # 对best_truth_idx内容进行设置
    for j in range(best_prior_idx.size(0)):
        best_truth_idx[best_prior_idx[j]] = j
    # 找到每个先验框重合程度最好的真实框
    matches = truths[best_truth_idx]
    conf = labels[best_truth_idx] + 1
    
    # 如果重合程度小于threhold则认为是背景
    conf[best_truth_overlap < threshold] = 0
    # 偏移量学习 
    loc = encode(matches, priors, variances)
    loc_t[idx] = loc
    # 每个先验框的最优标签
    conf_t[idx] = conf

这部分函数虽然看似不长,但是真的是不怎么好理解,首先看第一步:

	# 计算所有的先验框和真实框的重合程度
    overlaps = jaccard(
        truths,
        point_form(priors))

计算所有先验框与真实框的重合度,传入两个参数,一个是真实的先验框,来自于标注的图像,另一个是net输出的先验框,也就是自己划分的先验框。然后再转到jaccard函数。

def jaccard(box_a, box_b):
    inter = intersect(box_a, box_b)
    # 计算先验框和真实框各自的面积
    area_a = ((box_a[:, 2]-box_a[:, 0]) *
              (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
    area_b = ((box_b[:, 2]-box_b[:, 0]) *
              (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
    # 求IOU
    union = area_a + area_b - inter
    return inter / union  # [A,B]

Jaccard函数主要目的是计算真实框与先验框重合部分的面积与相交的面积比,即A交B比上A并B,如图:即图中的C / ((A+B) - C)
在这里插入图片描述

既然知道了要算什么,那就看看代码是怎么实现的吧,jaccard中的area_a和area_b就是分别算了A和B的面积,而C的面积,好吧,又一个函数,intersect。

def intersect(box_a, box_b):
    A = box_a.size(0)
    B = box_b.size(0)
    max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
                       box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
    min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
                       box_b[:, :2].unsqueeze(0).expand(A, B, 2))
    inter = torch.clamp((max_xy - min_xy), min=0)

    # 计算先验框和所有真实框的重合面积
    return inter[:, :, 0] * inter[:, :, 1]

Intersect函数中,max_xy计算的是A和B两个左下角点中像素值较小的那一个,min_xy计算的是A和B两个左下角点中像素值较大的那一个,这样就可以得到C的左上角点和右下角点,inter中max_xy和min_xy相减,就得到了C的两条边长,然后返回两边长的乘积,就得到相交部分C的面积。然后再回到jaccard中就可以很清楚的知道那个返回值返回的就是一个面积比。再解释两行特殊的代码:

def point_form(boxes):
    return torch.cat((boxes[:, :2] - boxes[:, 2:]/2,     # xmin, ymin
                     boxes[:, :2] + boxes[:, 2:]/2), 1)  # xmax, ymax

这个是因为先验框本来是中点和宽高形式的,但是真实框是左上点和右下点的形式的,所以要转换成同样的形式。

box_a[:, 2:].unsqueeze(1).expand(A, B, 2),

这句代码的目的是把真实框的个数变的和先验框一样多,就是将真实框复制,这样就方便一步计算,不然就需要循环一步一步算,显然这是torch语法的方便之处。
接下来看看这个重合度是怎么使用的:

# 所有真实框和先验框的最好重合程度
best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
best_prior_idx.squeeze_(1)
best_prior_overlap.squeeze_(1)

# 所有先验框和真实框的最好重合程度
# [1,prior]
best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
best_truth_idx.squeeze_(0)
best_truth_overlap.squeeze_(0)

这里上面是真实框和先验框的最好重合度,而下面是先验框和真实框的最好重合度,乍一看这不是一个事么,但是这确实不是一个事,举个栗子:
在这里插入图片描述

如图:行数代表真实框的个数,列数代表先验框的个数,显然一幅图中的真实框个数是比较少的,先验框是比较多的,先说真实框和先验框的最好重合度,其实就是所有的先验框与所有的真实框去比较,体现在代码中,就是真实框假设我有三个,那就找出三个与真实框重合度最好的先验框,也就是说如上图中的样子,在每一行中寻找一个最大值,而先验框和真实框的最好重合度则是要为每一个先验框都匹配一个最大值,所以要按列去每列中寻找一个最大值。

# 找到每个先验框重合程度最好的真实框
matches = truths[best_truth_idx]
conf = labels[best_truth_idx] + 1

# 如果重合程度小于threhold则认为是背景
conf[best_truth_overlap < threshold] = 0

这三行代码最终目的是为了得到重合度最好的先验框所在的位置
在这里插入图片描述

如图可以看到1所在的位置就是与真实框重合度最好的先验框,因为我只有一个真实框,所以也就只有一个1。

	# 偏移量学习 
    loc = encode(matches, priors, variances)

Match中的最后一行代码,调用了一个encode的函数。

	def encode(matched, priors, variances):
	    g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
	    g_cxcy /= (variances[0] * priors[:, 2:])
	    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
	    g_wh = torch.log(g_wh) / variances[1]

	    return torch.cat([g_cxcy, g_wh], 1)

这个函数内容很简单,就是简单的数学运算,前两行就是对先验框中心点位置的调整,后两行是对先验框宽高的调整,至于为什么这么算,这个应该是SSD开发者的经验和智慧了。

# 所有conf_t>0的地方,代表内部包含物体
pos = conf_t > 0
# 求和得到每一个图片内部有多少正样本
num_pos = pos.sum(dim=1, keepdim=True)   
# 计算回归loss
pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
loc_p = loc_data[pos_idx].view(-1, 4)
loc_t = loc_t[pos_idx].view(-1, 4)
loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)

这部分代码是计算先验框的回归loss,这里用了smooth_l1_loss的损失函数,pytorch中几个常用的损失函数具体可以自行到网上查。接下来看下分类的损失函数。

	 # 你可以把softmax函数看成一种接受任何数字并转换为概率分布的非线性方法
    # 获得每个框预测到真实框的类的概率
    loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

这一步就是交叉熵损失函数的计算方法,可以参考下面这个公式,这行代码就是实现了这个公式。目的是为了计算每一个框预测到真实框的概率。
在这里插入图片描述

    # 获得每一张图新的softmax的结果
    _, loss_idx = loss_c.sort(1, descending=True)
    _, idx_rank = loss_idx.sort(1)
    # 计算每一张图的正样本数量
    num_pos = pos.long().sum(1, keepdim=True)
    # 限制负样本数量
    num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)	-1)
    neg = idx_rank < num_neg.expand_as(idx_rank)

接下来为计算得到的概率进行排序,将概率最大的框放在最前面,然后获取到正样本的数量,再根据正样本的数量去限制负样本的数量,因为如果不限制的话负样本会远多于正样本,而一般负样本是正样本的三倍就够了。

	# 计算正样本的loss和负样本的loss
    pos_idx = pos.unsqueeze(2).expand_as(conf_data)
    neg_idx = neg.unsqueeze(2).expand_as(conf_data)
    
    conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_c	lasses)
    targets_weighted = conf_t[(pos+neg).gt(0)]
    loss_c = F.cross_entropy(conf_p, targets_weighted, size_average	=False)

最后这部分重点了解一下torch.gt,这里的conf_p中保存的是预测的每个框的内容的类别置信度,targets_weighted中保存的是框内的真实标签,然后计算loss值。

有问题可以加入我的群,QQ群号109530447

  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值