def compute_loss(p, targets, model):
    """Compute the combined box, objectness and classification loss.

    Args:
        p: Sequence of per-layer raw prediction tensors. Each is indexed as
            ``pi[b, a, gj, gi]``; its last dimension holds box x/y (0:2),
            box w/h (2:4), objectness (4) and class logits (5:).
            Presumably shaped (batch, anchors, grid_y, grid_x, outputs) --
            TODO confirm against the model head.
        targets: Ground-truth tensor forwarded to ``build_targets``; its
            device is used for every tensor created here.
        model: Model exposing ``hyp`` (loss hyper-parameters: ``cls_pw``,
            ``obj_pw``, ``fl_gamma``, ``box``, ``obj``, ``cls``), ``gr``
            (IoU weight in the objectness target) and ``nc`` (class count).

    Returns:
        Tuple ``(loss * bs, parts)``: ``bs`` is the first dimension of the
        last layer's prediction tensor, ``parts`` is a detached 1-D tensor
        ``[lbox, lobj, lcls, loss]``.
    """
    device = targets.device
    lcls = torch.zeros(1, device=device)
    lbox = torch.zeros(1, device=device)
    lobj = torch.zeros(1, device=device)
    tcls, tbox, indices, anchors = build_targets(p, targets, model)
    h = model.hyp

    # BCE criteria; pos_weight re-weights the positive term of each loss.
    BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['cls_pw']])).to(device)
    BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['obj_pw']])).to(device)

    # Positive / negative class-target values (called with eps=0.0, so
    # presumably no label smoothing -- confirm smooth_BCE semantics).
    cp, cn = smooth_BCE(eps=0.0)

    # Wrap both criteria in focal loss when a positive gamma is configured.
    g = h['fl_gamma']
    if g > 0:
        BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

    no = len(p)  # number of output layers
    # Per-layer objectness weights. NOTE(review): only 3- or 4-layer heads are
    # anticipated; more than 4 layers would raise IndexError at balance[i].
    balance = [4.0, 1.0, 0.4] if no == 3 else [4.0, 1.0, 0.4, 0.1]

    for i, pi in enumerate(p):
        b, a, gj, gi = indices[i]
        tobj = torch.zeros_like(pi[..., 0], device=device)
        n = b.shape[0]  # number of assigned targets on this layer
        if n:
            ps = pi[b, a, gj, gi]  # predictions at the target-assigned cells

            # Decode boxes: xy offset in (-0.5, 1.5), wh in (0, 4) x anchor.
            pxy = ps[:, :2].sigmoid() * 2. - 0.5
            pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
            pbox = torch.cat((pxy, pwh), 1).to(device)
            iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)
            lbox += (1.0 - iou).mean()

            # Objectness target blends a constant 1.0 with the detached,
            # non-negative IoU, weighted by model.gr.
            tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype)

            # Classification loss is only meaningful with more than one class.
            if model.nc > 1:
                t = torch.full_like(ps[:, 5:], cn, device=device)
                t[range(n), tcls[i]] = cp
                lcls += BCEcls(ps[:, 5:], t)

        lobj += BCEobj(pi[..., 4], tobj) * balance[i]

    # Rescale the gains relative to a 3-layer head.
    s = 3 / no
    lbox *= h['box'] * s
    lobj *= h['obj'] * s * (1.4 if no == 4 else 1.)
    lcls *= h['cls'] * s
    bs = tobj.shape[0]
    loss = lbox + lobj + lcls
    return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
def build_targets(p, targets, model):
    """Match ground-truth boxes to anchors and grid cells for every detection layer.

    Args:
        p: Per-layer prediction tensors; only their shapes are read
            (dimension 2 is the grid height, dimension 3 the grid width).
        targets: One row per ground-truth box. Columns used: 0 batch/image
            index, 1 class, 2:4 box centre x/y, 4:6 box width/height. Columns
            2:6 are multiplied by the grid size, so they are presumably
            normalised to [0, 1] -- TODO confirm with the data loader.
        model: Model (optionally parallel-wrapped) whose last module provides
            ``na`` (anchors per layer), ``nl`` (layer count) and ``anchors``;
            ``model.hyp['anchor_t']`` is the max box/anchor size ratio.

    Returns:
        ``(tcls, tbox, indices, anch)``, four lists with one entry per layer:
        class indices; boxes as (xy offset from the assigned cell, wh) in grid
        units; index tuples ``(image, anchor, grid_y, grid_x)``; and the
        matched anchor sizes.
    """
    det = model.module.model[-1] if is_parallel(model) else model.model[-1]
    na, nt = det.na, targets.shape[0]
    tcls, tbox, indices, anch = [], [], [], []
    gain = torch.ones(7, device=targets.device)

    # Replicate every target once per anchor and append the anchor index as
    # column 6, giving a (na, nt, 7) tensor.
    ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)
    targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)

    # Cell offsets: centre, then left, top, right, bottom neighbour
    # (subtracted from the centre xy below), scaled by g.
    g = 0.5
    off = torch.tensor([[0, 0],
                        [1, 0], [0, 1], [-1, 0], [0, -1],
                        ], device=targets.device).float() * g

    for i in range(det.nl):
        anchors = det.anchors[i]
        # Integer grid size; used for clamping so the bounds match the
        # int64 index tensors.
        grid_h, grid_w = p[i].shape[2], p[i].shape[3]
        gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # (w, h, w, h)
        t = targets * gain  # scale boxes to grid units

        if nt:
            # Keep target/anchor pairs whose w/h ratio (either direction)
            # stays below anchor_t.
            r = t[:, :, 4:6] / anchors[:, None]
            keep = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t']
            t = t[keep]

            # Also assign neighbouring cells. A centre whose fractional part
            # is < g adds the adjacent cell on that side, provided the centre
            # is > 1 cell away from that border.
            gxy = t[:, 2:4]
            gxi = gain[[2, 3]] - gxy  # distance to the right/bottom border
            left, top = ((gxy % 1. < g) & (gxy > 1.)).T
            right, bottom = ((gxi % 1. < g) & (gxi > 1.)).T
            sel = torch.stack((torch.ones_like(left), left, top, right, bottom))
            t = t.repeat((5, 1, 1))[sel]
            offsets = (torch.zeros_like(gxy)[None] + off[:, None])[sel]
        else:
            t = targets[0]  # empty (0, 7) tensor
            offsets = 0

        b, c = t[:, :2].long().T  # image index, class
        gxy = t[:, 2:4]
        gwh = t[:, 4:6]
        gij = (gxy - offsets).long()
        gi, gj = gij.T  # views into gij
        a = t[:, 6].long()  # anchor index

        # clamp_ acts in place on views of gij, so the tbox offsets below are
        # taken relative to the clamped cell. Integer bounds avoid the
        # Float -> Long in-place cast error raised by tensor bounds.
        indices.append((b, a, gj.clamp_(0, grid_h - 1), gi.clamp_(0, grid_w - 1)))
        tbox.append(torch.cat((gxy - gij, gwh), 1))
        anch.append(anchors[a])
        tcls.append(c)

    return tcls, tbox, indices, anch