代码库:https://github.com/ZiweiWangTHU/BiDet
源代码(SSD训练为例): ssd/train_bidet_ssd.py
if PRIOR_LOSS_WEIGHT != 0.: # 如果先验损失的权重不为0
loss_count = 0. # 初始化计数器
detect_result = net.module.detect_prior.forward(
loc_data, # localization preds
net.module.softmax(conf_data), # confidence preds
priors, # default boxes
gt_class
) # [batch, classes, top_k, 5 (score, (y1, x1, y2, x2))] # 计算检测结果
num_classes = detect_result.size(1) # 获取类别数目
# skip j = 0, because it's the background class
for j in range(1, num_classes): # 遍历每个类别,跳过背景类
all_dets = detect_result[:, j, :, :] # [batch, top_k, 5] # 获取每个批次每个类别的检测结果
all_mask = all_dets[:, :, :1].gt(0.).expand_as(all_dets) # [batch, top_k, 5] # 获取每个批次每个类别的有效检测结果的掩码
for batch_idx in range(batch_size): # 遍历每个批次
# skip non-existed class
if not (gt_class[batch_idx] == j - 1).any(): # 如果这个批次没有这个类别的真实目标,跳过
continue
dets = torch.masked_select(all_dets[batch_idx], all_mask[batch_idx]).view(-1, 5) # [num, 5] # 根据掩码选择出有效的检测结果
if dets.size(0) == 0: # 如果没有有效的检测结果,跳过
continue
# if pred num == gt num, skip
if dets.size(0) <= ((gt_class[batch_idx] == j - 1).sum().detach().cpu().item()): # 如果预测的目标数量小于等于真实的目标数量,跳过
continue
scores = dets[:, 0] # [num] # 获取检测结果中的得分部分
scores_sum = scores.sum().item() # no grad # 计算得分之和,不需要梯度
scores = scores / scores_sum # normalization # 对得分进行归一化处理
log_scores = log_func(scores) # 计算得分的对数
gt_num = (gt_class[batch_idx] == j - 1).sum().detach().cpu().item() # 计算真实目标数量
loss_p += (-1. * log_scores.sum() / float(gt_num)) # 计算先验损失,并累加到总损失中,注意要除以真实目标数量进行归一化
loss_count += 1. # 更新计数器
loss_p /= (loss_count + 1e-6) # 计算平均先验损失,注意要加上一个很小的数避免除零错误
loss_p *= PRIOR_LOSS_WEIGHT # 将平均先验损失乘以权重系数
这一段代码是BiDet算法中的另一个重要部分,它实现了先验损失(Prior Loss,PL)的功能。PL的目的是在训练过程中增加分类分支的随机性,从而提高二值化网络的分类精度。
具体来说,这段代码的作用是计算每个类别的预测得分和真实标签之间的交叉熵损失,然后对每个类别进行归一化,使得每个类别的损失与其真实目标数量成反比。这样做的好处是可以避免二值化网络过拟合到一个固定的分类值,而是让它能够适应不同的分类情况。
这段代码中,detect_result是根据定位数据、置信度数据、先验框和真实标签计算出来的检测结果,它包含了每个类别的得分和边界框坐标²。num_classes是类别数目。all_dets是每个批次每个类别的检测结果,all_mask是用来过滤掉得分为0的检测结果的掩码。dets是根据掩码选择出来的有效检测结果,scores是检测结果中的得分部分,scores_sum是得分之和,log_scores是得分的对数。gt_num是真实目标数量。loss_p是先验损失,它等于所有批次所有类别的负对数得分之和除以真实目标数量之和。loss_count是用来计算平均损失的计数器。