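Preface
This article walks through the MultiBoxLoss file in the SSD-Pytorch source code.
Keywords: multiboxloss, SSD, pytorch, object detection
Note: The author's knowledge is limited, so corrections are welcome. If you repost this article, please credit the source. Thank you.
Contact:
Email: yue_zhan@yahoo.com
QQ: 1156356625
Main text
First, the class docstring: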
class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function
    Compute Targets:
        1) Produce Confidence Target Indices by matching ground truth boxes
           with (default) 'priorboxes' that have jaccard index > threshold parameter
           (default threshold: 0.5).
        2) Produce localization target by 'encoding' variance into offsets of ground
           truth boxes and their matched 'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative examples
           that comes with using a large number of default bounding boxes.
           (default negative:positive ratio 3:1)
    Objective Loss:
        L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
        weighted by α which is set to 1 by cross val.
        Args:
            c: class confidences,
            l: predicted boxes,
            g: ground truth boxes
            N: number of matched default boxes
        See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """
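The docstring has two parts:
- Computing the targets
  This takes three steps: (1) match the prior boxes against the ground truth boxes by overlap; (2) encode each matched ground truth box as offsets relative to its prior box, so that it has the same format as the predictions, with variance rescaling the offsets to amplify gradients and speed up training; (3) apply hard negative mining, which keeps only the hardest negatives (3:1 against positives by default) to improve classification.
- The loss formula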
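Steps (1) and (3) can be illustrated with a minimal pure-Python sketch. This is not the repository's tensor code: the function names `jaccard`, `match_priors` and `hard_negative_indices` are hypothetical, all boxes are assumed to be in corner form (xmin, ymin, xmax, ymax) with positive area, and the matching is simplified to "best-overlapping ground truth per prior".

```python
def jaccard(box_a, box_b):
    """Jaccard index (IoU) of two boxes given as (xmin, ymin, xmax, ymax)."""
    ix = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    iy = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = ix * iy
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)


def match_priors(gt_boxes, gt_labels, priors, threshold=0.5, background=0):
    """Step 1 (simplified): each prior takes the label of its best-overlapping
    ground truth box, or the background label if that overlap is too small."""
    matched = []
    for prior in priors:
        overlaps = [jaccard(gt, prior) for gt in gt_boxes]
        best = max(range(len(gt_boxes)), key=lambda i: overlaps[i])
        label = gt_labels[best] if overlaps[best] >= threshold else background
        matched.append((best, label))
    return matched


def hard_negative_indices(conf_losses, labels, negpos_ratio=3, background=0):
    """Step 3: keep the negatives with the largest confidence loss,
    at most negpos_ratio times the number of positives."""
    num_pos = sum(1 for lab in labels if lab != background)
    negatives = [i for i, lab in enumerate(labels) if lab == background]
    negatives.sort(key=lambda i: conf_losses[i], reverse=True)
    return negatives[:negpos_ratio * num_pos]


# Toy data (made up for illustration): one ground truth box, four priors.
gt_boxes = [(0.0, 0.0, 0.5, 0.5)]
gt_labels = [1]
priors = [(0.0, 0.0, 0.5, 0.5), (0.0, 0.0, 0.5, 0.4),
          (0.5, 0.5, 1.0, 1.0), (0.25, 0.25, 0.75, 0.75)]
labels = [lab for _, lab in match_priors(gt_boxes, gt_labels, priors)]

# Toy per-prior confidence losses for a second example with one positive.
toy_labels = [1, 0, 0, 0, 0, 0, 0]
toy_losses = [0.1, 0.2, 0.9, 0.3, 0.8, 0.05, 0.7]
kept_negatives = hard_negative_indices(toy_losses, toy_labels)
```

With one positive and a 3:1 ratio, only the three negatives with the highest loss contribute to the confidence loss. The remaining negatives are ignored.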
Next is the initialization:
    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True):
        super(MultiBoxLoss, self).__init__()
        self.use_gpu = use_gpu
        # Number of classes: usually the dataset's K classes + 1 (background)
        self.num_classes = num_classes
        # Overlap threshold used for matching, default 0.5
        self.threshold = overlap_thresh
        # Label index of the background class, default 0.
        # To change it (e.g. to make background the last class), the label
        # generation in the Dataloader's __init__() must be changed as well.
        self.background_label = bkg_label
        # The following parameters were probably used for debugging; forward does not use them
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        # Negative:positive sample ratio, default 3:1
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        # Variance that scales the encoded offsets (gradient scaling, faster training)
        self.variance = cfg['variance']
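This variance is a training trick, and DPR has a similar operation: when computing the Smooth-L1 loss, the input x is rescaled to amplify the gradient and speed up convergence.
Then comes the forward part: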
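To make the effect of the variance concrete, here is a minimal pure-Python sketch of an SSD-style box encoding followed by Smooth-L1. It rests on several assumptions not stated in the text above: the helper names `encode`, `smooth_l1` and `smooth_l1_grad` are hypothetical, the variance values (0.1, 0.2) are the commonly used SSD defaults rather than something read from `cfg`, the ground truth box is in corner form, and the prior is in center form (cx, cy, w, h).

```python
import math


def encode(matched, prior, variances=(0.1, 0.2)):
    """Encode a ground truth box (xmin, ymin, xmax, ymax) as offsets
    relative to a prior (cx, cy, w, h), divided by the variances."""
    g_cx = ((matched[0] + matched[2]) / 2 - prior[0]) / (variances[0] * prior[2])
    g_cy = ((matched[1] + matched[3]) / 2 - prior[1]) / (variances[0] * prior[3])
    g_w = math.log((matched[2] - matched[0]) / prior[2]) / variances[1]
    g_h = math.log((matched[3] - matched[1]) / prior[3]) / variances[1]
    return (g_cx, g_cy, g_w, g_h)


def smooth_l1(x):
    """Smooth-L1: quadratic for |x| < 1, linear otherwise."""
    ax = abs(x)
    return 0.5 * x * x if ax < 1.0 else ax - 0.5


def smooth_l1_grad(x):
    """Derivative of Smooth-L1: x in the quadratic zone, sign(x) outside."""
    if abs(x) < 1.0:
        return x
    return 1.0 if x > 0 else -1.0


# Toy boxes (made up): the ground truth center is shifted by 0.02 in x.
prior = (0.5, 0.5, 0.2, 0.2)
gt = (0.42, 0.4, 0.62, 0.6)

unscaled = encode(gt, prior, variances=(1.0, 1.0))  # no variance scaling
scaled = encode(gt, prior)                          # divided by 0.1 / 0.2

# In the quadratic zone the gradient equals the residual, so a target that
# is 10 times larger yields a 10 times larger gradient for the same error.
grad_unscaled = smooth_l1_grad(0.5 * unscaled[0])
grad_scaled = smooth_l1_grad(0.5 * scaled[0])
```

In this toy case the x offset grows from about 0.1 to about 1.0 once it is divided by the variance, which is the amplification the paragraph above describes.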
    def forward(self, predictions, targets):
        """Multibox Loss
        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
            and prior boxes from SSD net.
                conf shape: torch.size(batch_size,num_priors,num_classes)
                loc shape: torch.size(batch_size,num_priors,4)
                priors shape: torch.size(num_priors,4)
            targets (tensor): Ground truth boxes and labels for a batch,
                shape: [batch_size,num_objs,5] (last idx is the label).
        """
        loc_data, conf_data, priors = predictions
        num = loc_data.size