- `build_targets`：选取正样本，并通过相邻网格扩充正样本
def build_targets(pred, targets, model):
    """Select positive samples for every detection layer and add neighbour-cell positives.

    For each detection layer, a target is matched to every anchor whose width/height
    ratio to the target (taken in both directions) is below ``model.hyp['anchor_t']``.
    Each matched target is then also assigned to the adjacent grid cell(s) on the side
    its centre is closest to (left/up when the fractional part is < 0.5, right/down when
    it is > 0.5), as long as that neighbour lies inside the grid.

    Args:
        pred: per-layer predictions; ``pred[i].shape[2]`` / ``pred[i].shape[3]`` are read
            as the grid height / width of layer ``i``.
        targets: tensor of shape (n, 6) read as (image, class, x, y, w, h); the box
            columns are multiplied by the grid size, so they are presumably normalised
            to [0, 1] (TODO confirm against the data loader). Not modified.
        model: the model, optionally wrapped in DataParallel / DistributedDataParallel,
            whose last module is the detection layer.

    Returns:
        ``(targets_class, targets_box, indices, anchors)``, each a list with one entry per
        detection layer: class ids, boxes as (dx, dy, w, h) relative to the assigned cell,
        ``(image, anchor, grid_y, grid_x)`` index tuples, and the matched anchor sizes.
    """
    parallel_types = (torch.nn.parallel.DataParallel, torch.nn.parallel.DistributedDataParallel)
    detect_layer = model.module.model[-1] if isinstance(model, parallel_types) else model.model[-1]
    number_anchor_per_pixel, number_targets = detect_layer.number_anchor_per_pixel, targets.shape[0]
    device = targets.device
    targets_class, targets_box, indices, anchors = [], [], [], []
    gain = torch.ones(6, device=device)
    # Neighbour directions (dx, dy) for the left, up, right and down cells.
    offset_direction = torch.tensor(((-1, 0), (0, -1), (1, 0), (0, 1)), device=device).float()
    # anchor_index[a, t] == a for every (anchor, target) pair; created on the targets'
    # device so it can be indexed with the match mask below.
    anchor_index = torch.arange(number_anchor_per_pixel, device=device) \
        .view(number_anchor_per_pixel, 1).repeat(1, number_targets)
    offset_threshold = 0.5
    for i in range(detect_layer.number_detection_layer):
        anchor = detect_layer.anchors[i]
        height, width = pred[i].shape[2], pred[i].shape[3]
        gain[2:] = torch.tensor((width, height, width, height), device=device)
        # Per-layer results live in fresh locals: `targets` and `anchor_index` must stay
        # untouched so the next detection layer starts from the original targets.
        layer_targets = targets * gain
        layer_anchors = torch.zeros(0, dtype=torch.long, device=device)
        offsets = 0
        if number_targets:
            ratio_wh = layer_targets[None, :, 4:] / anchor[:, None]
            matched = torch.max(ratio_wh, 1. / ratio_wh).max(2)[0] < model.hyp['anchor_t']
            layer_anchors = anchor_index[matched]
            layer_targets = layer_targets.repeat(number_anchor_per_pixel, 1, 1)[matched]

            grid_xy = layer_targets[:, 2:4]
            left_offset, up_offset = ((grid_xy % 1. < offset_threshold) & (grid_xy > 1.)).T
            right_offset, down_offset = \
                ((grid_xy % 1. > (1 - offset_threshold)) & (grid_xy < (gain[[2, 3]] - 1.))).T
            masks = (left_offset, up_offset, right_offset, down_offset)
            layer_anchors = torch.cat([layer_anchors] + [layer_anchors[m] for m in masks], 0)
            layer_targets = torch.cat([layer_targets] + [layer_targets[m] for m in masks], 0)
            zero = torch.zeros_like(grid_xy)
            offsets = torch.cat(
                [zero] + [zero[m] + direction for m, direction in zip(masks, offset_direction)], 0
            ) * offset_threshold

        grid_xy = layer_targets[:, 2:4]
        grid_wh = layer_targets[:, 4:]
        grid_ij = (grid_xy + offsets).long()
        # Keep cell indices inside the grid (e.g. a target centred exactly on the right or
        # bottom border would otherwise index one cell past the end).
        grid_ij[:, 0].clamp_(0, width - 1)
        grid_ij[:, 1].clamp_(0, height - 1)
        grid_i, grid_j = grid_ij.T
        targets_class.append(layer_targets[:, 1].long())
        targets_box.append(torch.cat((grid_xy - grid_ij, grid_wh), 1))
        indices.append((layer_targets[:, 0].long(), layer_anchors, grid_j, grid_i))
        anchors.append(anchor[layer_anchors])
    return targets_class, targets_box, indices, anchors
- `compute_loss`：计算损失（边框损失、目标置信度损失、分类损失）
def compute_loss(pred, targets, model):
    """Compute the weighted detection loss (box + objectness + class) over all layers.

    Args:
        pred: per-layer raw predictions indexed as
            ``pred[i][image, anchor, grid_y, grid_x, channel]``; channels 0-3 are box
            terms, channel 4 is objectness and channels 5+ are class logits.
        targets: ground-truth tensor, passed unchanged to ``build_targets``; its device
            is used for all loss tensors.
        model: model exposing ``hyp`` (loss gains ``giou``/``obj``/``cls`` and BCE
            positive weights ``cls_pw``/``obj_pw``), ``number_class`` and ``iou_ratio``.

    Returns:
        ``(loss * batch_size, components)`` where ``loss`` is a 1-element tensor and
        ``components`` is the detached tensor ``(loss_box, loss_obj, loss_cls, loss)``.
    """
    device = targets.device
    # 1-element float32 accumulators. (The original built these with `torch.Tensor()`,
    # an empty tensor instance, which is not callable and crashed on CPU.)
    loss_box = torch.zeros(1, device=device)
    loss_cls = torch.zeros(1, device=device)
    loss_obj = torch.zeros(1, device=device)
    targets_class, targets_box, indices, anchors = build_targets(pred, targets, model)
    reduction = 'mean'
    BCE_cls = torch.nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([model.hyp['cls_pw']], dtype=torch.float32, device=device),
        reduction=reduction)
    BCE_obj = torch.nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([model.hyp['obj_pw']], dtype=torch.float32, device=device),
        reduction=reduction)
    class_positive, class_negative = smooth_BCE(eps=0.0)
    number_layers = len(pred)
    # Per-layer weights applied to the objectness loss, in `pred` order.
    balance = [4.0, 1.0, 0.4] if number_layers == 3 else [4.0, 1.0, 0.4, 0.1]
    for i, pred_i in enumerate(pred):
        positive_img_index, positive_anch_index, positive_grid_y, positive_grid_x = indices[i]
        number_targets = positive_img_index.shape[0]
        targets_obj = torch.zeros_like(pred_i[..., 4]).to(device)
        if number_targets:
            pred_positive = pred_i[positive_img_index, positive_anch_index, positive_grid_y, positive_grid_x]
            # Decode xy into a (-0.5, 1.5) offset from the cell and wh into (0, 4) x anchor size.
            pred_xy = pred_positive[:, 0:2].sigmoid() * 2 - 0.5
            pred_wh = (pred_positive[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
            pred_box = torch.cat((pred_xy, pred_wh), 1).to(device)
            iou = bbox_iou(pred_box.t(), targets_box[i], x1y1x2y2=False, iou_type='MIOU-C')
            loss_box += (1.0 - iou).sum() if reduction == 'sum' else (1.0 - iou).mean()
            if model.number_class > 1:
                targets_cls = torch.full_like(pred_positive[:, 5:], class_negative).to(device)
                targets_cls[range(number_targets), targets_class[i]] = class_positive
                loss_cls += BCE_cls(pred_positive[:, 5:], targets_cls)
            # Objectness target blends a constant 1 with the detached, non-negative IoU.
            targets_obj[positive_img_index, positive_anch_index, positive_grid_y, positive_grid_x] = \
                (1.0 - model.iou_ratio) + model.iou_ratio * iou.detach().clamp(0).type(targets_obj.dtype)
        loss_obj += BCE_obj(pred_i[..., 4], targets_obj) * balance[i]
    # Gains are scaled by 3 / number_layers.
    s = 3 / number_layers
    batch_size = pred[0].shape[0]
    loss_box *= model.hyp['giou'] * s
    loss_cls *= model.hyp['cls'] * s
    # 4-layer heads get an extra 1.4x objectness gain. (The original compared the
    # undefined local `np` with 4, so this never triggered.)
    loss_obj *= model.hyp['obj'] * s * (1.4 if number_layers == 4 else 1.)
    loss = loss_box + loss_cls + loss_obj
    return loss * batch_size, torch.cat((loss_box, loss_obj, loss_cls, loss)).detach()