2021SC@SDUSC
CenterNet之loss计算代码解析(接上文)
代码解析
来自train.py中第173行开始进行loss计算:
# 得到heat map, reg, wh 三个变量
hmap, regs, w_h_ = zip(*outputs)
regs = [
_tranpose_and_gather_feature(r, batch['inds']) for r in regs
]
w_h_ = [
_tranpose_and_gather_feature(r, batch['inds']) for r in w_h_
]
# 分别计算loss
hmap_loss = _neg_loss(hmap, batch['hmap'])
reg_loss = _reg_loss(regs, batch['regs'], batch['ind_masks'])
w_h_loss = _reg_loss(w_h_, batch['w_h_'], batch['ind_masks'])
# 进行loss加权,得到最终loss
loss = hmap_loss + 1 * reg_loss + 0.1 * w_h_loss
上述transpose_and_gather_feature函数具体实现如下,主要功能是将ground truth中计算得到的对应中心点的值获取。