if opt.wh_weight > 0:
    # Dense form: the training label batch['dense_wh'] is a per-pixel wh map of
    # shape (batch, 2, output_w, output_h), unlike the sparse batch['wh'] label
    # of shape (max_objs, 2).
    if opt.dense_wh:
        mask_weight = batch['dense_wh_mask'].sum() + 1e-4
        # batch['dense_wh_mask'] marks the positions of the 'wh' output head
        # (output_w * output_h * 2) that contain an object center (aligned with
        # the heatmap).  Multiplying output['wh'] by it guarantees the
        # prediction only has values at object-center positions; everywhere
        # else is zeroed, so only centers are supervised.
        wh_loss += (
            self.crit_wh(output['wh'] * batch['dense_wh_mask'],
                         batch['dense_wh'] * batch['dense_wh_mask']) /
            mask_weight) / opt.num_stacks
    # Category-specific wh: separate wh targets/masks per object category.
    elif opt.cat_spec_wh:
        wh_loss += self.crit_wh(
            output['wh'], batch['cat_spec_mask'],
            batch['ind'], batch['cat_spec_wh']) / opt.num_stacks
    # Default sparse form: L1 regression at the max_objs center indices.
    else:
        wh_loss += self.crit_reg(
            output['wh'], batch['reg_mask'],
            batch['ind'], batch['wh']) / opt.num_stacks
def forward(self, output, mask, ind, target):
    """Masked L1 regression loss for sparse per-object targets (wh / offset heads).

    Args:
        output: predicted head map, shape (batch, C, H, W).
        mask:   (batch, max_objs) 0/1 mask marking valid object slots.
        ind:    (batch, max_objs) flattened H*W indices of object centers.
        target: (batch, max_objs, C) ground-truth values.

    Returns:
        Scalar tensor: summed L1 over valid entries, normalized by the
        number of valid (masked-in) elements.
    """
    # Gather predictions at the object-center locations so that pred has
    # the same (batch, max_objs, C) layout as target.
    pred = _tranpose_and_gather_feat(output, ind)
    # Broadcast the per-object mask across the channel dimension.
    mask = mask.unsqueeze(2).expand_as(pred).float()
    # FIX: `size_average=False` is the deprecated legacy reduction API
    # (removed in favor of `reduction=` in PyTorch >= 0.4.1);
    # reduction='sum' is the exact equivalent: sum all elements, no mean.
    loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
    # Normalize by the number of valid entries; 1e-4 guards against
    # division by zero when the batch contains no objects.
    loss = loss / (mask.sum() + 1e-4)
    return loss
def _tranpose_and_gather_feat(feat, ind):
    """Pick the feature vectors at the spatial positions given by `ind`.

    Converts a (B, C, H, W) head output into (B, K, C): the spatial dims
    are flattened and the K per-object center locations are gathered, so
    the result lines up with the sparse (B, K, C) ground-truth tensors
    (the dense (B, C, H, W) output cannot be compared to them directly).
    """
    batch, channels = feat.size(0), feat.size(1)
    # (B, C, H, W) -> (B, H*W, C): channels last, spatial dims flattened.
    flat = feat.permute(0, 2, 3, 1).contiguous().view(batch, -1, channels)
    # Replicate each spatial index across the channel axis: (B, K) -> (B, K, C).
    idx = ind.unsqueeze(2).expand(batch, ind.size(1), channels)
    # Pull whole C-dimensional feature rows at the indexed positions.
    return flat.gather(1, idx)
def _gather_feat(feat, ind, mask=None):
#dim = 2
dim = feat.size(2)
#ind维度 :32*50---->32*50*2
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
#从feat的第1个维度,按ind给出的索引提取元素
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat