Data Loading
with timer.env('Load Data'):
    # img: (550, 550, 3)
    # gt: (3, 5) -- 3 objects, each row being 4 normalized box coordinates (x1, y1, x2, y2) plus the class label
    # gt_masks: (3, 1080, 1920) -- 1080x1920 is the size of the original loaded image
    # h, w: 1080, 1920
    # num_crowd: 0
    img, gt, gt_masks, h, w, num_crowd = dataset.pull_item(image_idx)

# pull_item is defined in coco.py
def pull_item(self, index):
    ...
    return torch.from_numpy(img).permute(2, 0, 1), target, masks, height, width, num_crowds
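A quick shape sanity check on the item above (a sketch; the shapes are the ones for this particular image):

assert img.shape == (3, 550, 550)         # CHW tensor after the permute in pull_item
assert gt.shape == (3, 5)                 # one (x1, y1, x2, y2, label) row per object
assert gt_masks.shape == (3, 1080, 1920)  # one full-resolution binary mask per object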
forward
YOLACT decomposes instance segmentation into two parallel subtasks, one producing "prototype masks" and the other "mask coefficients":
- Protonet branch
Uses a fully convolutional network (FCN) to generate a set of "prototype masks" that are not tied to any particular instance but shared across the whole image; for each input image it predicts k (= 32) prototype masks.
- Prediction Head branch
Adds an extra head to the object detection branch (the anchor predictions) so that every instance/anchor also gets a vector of "mask coefficients". For each candidate box the head thus produces the class confidences, the anchor's location regression, and the prototype mask coefficients.
- Mask Assembly
P: the h×w×k prototype masks; C: the n×k matrix of mask coefficients. The final masks are a linear combination of the prototypes, M = σ(PCᵀ), as in the sketch after this list.
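A minimal sketch of the assembly step with toy shapes (the real tensors show up in postprocess below); σ is the sigmoid used for cfg.mask_proto_mask_activation:

import torch

h, w, k, n = 138, 138, 32, 6   # prototype resolution, prototype count, detection count
P = torch.rand(h, w, k)        # prototype masks from Protonet
C = torch.rand(n, k) * 2 - 1   # mask coefficients (the head applies tanh, so they lie in [-1, 1])
M = torch.sigmoid(P @ C.t())   # (h, w, n): one assembled mask per detection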
preds = net(batch)
# forward is defined in yolact.py
# x: torch.Size([1, 3, 550, 550])
def forward(self, x):
    _, _, img_h, img_w = x.size()
    cfg._tmp_img_h = img_h
    cfg._tmp_img_w = img_w

    with timer.env('backbone'):
        # outs is a tuple, len(outs) = 4
        # outs[0]: torch.Size([1, 256, 138, 138])
        # outs[1]: torch.Size([1, 512, 69, 69])
        # outs[2]: torch.Size([1, 1024, 35, 35])
        # outs[3]: torch.Size([1, 2048, 18, 18])
        outs = self.backbone(x)

    if cfg.fpn is not None:
        with timer.env('fpn'):
            # cfg.backbone.selected_layers = [1, 2, 3]
            outs = [outs[i] for i in cfg.backbone.selected_layers]
            # The FPN produces 5 outputs, len(outs) = 5
            # outs[0]: torch.Size([1, 256, 69, 69]) ----- P3
            # outs[1]: torch.Size([1, 256, 35, 35]) ----- P4
            # outs[2]: torch.Size([1, 256, 18, 18]) ----- P5
            # outs[3]: torch.Size([1, 256, 9, 9])   ----- P6
            # outs[4]: torch.Size([1, 256, 5, 5])   ----- P7
            outs = self.fpn(outs)
    proto_out = None
    if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
        with timer.env('proto'):
            # proto_x is outs[0], torch.Size([1, 256, 69, 69]) ----- P3
            proto_x = x if self.proto_src is None else outs[self.proto_src]

            if self.num_grids > 0:
                grids = self.grid.repeat(proto_x.size(0), 1, 1, 1)
                proto_x = torch.cat([proto_x, grids], dim=1)

            # The paper finds k = 32 prototypes to work best
            # proto_out: torch.Size([1, 32, 138, 138])
            proto_out = self.proto_net(proto_x)
            # The paper chooses ReLU here
            # cfg.mask_proto_prototype_activation: activation_func.relu
            proto_out = cfg.mask_proto_prototype_activation(proto_out)

            if cfg.mask_proto_prototypes_as_features:
                # Clone here because we don't want to permute this, though idk if contiguous makes this unnecessary
                proto_downsampled = proto_out.clone()

                if cfg.mask_proto_prototypes_as_features_no_grad:
                    proto_downsampled = proto_out.detach()

            # proto_out after the permute: torch.Size([1, 138, 138, 32])
            proto_out = proto_out.permute(0, 2, 3, 1).contiguous()

            if cfg.mask_proto_bias:
                bias_shape = [x for x in proto_out.size()]
                bias_shape[-1] = 1
                proto_out = torch.cat([proto_out, torch.ones(*bias_shape)], -1)
    with timer.env('pred_heads'):
        pred_outs = { 'loc': [], 'conf': [], 'mask': [], 'priors': [] }

        if cfg.use_mask_scoring:
            pred_outs['score'] = []

        if cfg.use_instance_coeff:
            pred_outs['inst'] = []

        # self.selected_layers: [0, 1, 2, 3, 4]
        # self.prediction_layers: ModuleList(
        #   (0): PredictionModule(
        #     (upfeature): Sequential(
        #       (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        #       (1): ReLU(inplace=True)
        #     )
        #     (bbox_layer): Conv2d(256, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        #     (conf_layer): Conv2d(256, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        #     (mask_layer): Conv2d(256, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        #   )
        #   (1): PredictionModule()
        #   (2): PredictionModule()
        #   (3): PredictionModule()
        #   (4): PredictionModule())
        for idx, pred_layer in zip(self.selected_layers, self.prediction_layers):
            pred_x = outs[idx]

            if cfg.mask_type == mask_type.lincomb and cfg.mask_proto_prototypes_as_features:
                # Scale the prototypes down to the current prediction layer's size and add it as inputs
                proto_downsampled = F.interpolate(proto_downsampled, size=outs[idx].size()[2:], mode='bilinear', align_corners=False)
                pred_x = torch.cat([pred_x, proto_downsampled], dim=1)

            # For idx = 1, 2, 3, 4 the layer is given ModuleList[0] as its parent, i.e. the parameters are shared
            if cfg.share_prediction_module and pred_layer is not self.prediction_layers[0]:
                pred_layer.parent = [self.prediction_layers[0]]

            # Output of ModuleList[0]: p = {'loc': torch.Size([1, 14283, 4]), 'conf': torch.Size([1, 14283, 10]), 'mask': torch.Size([1, 14283, 32]), 'priors': torch.Size([14283, 4])}
            # Output of ModuleList[1]: p = {'loc': torch.Size([1, 3675, 4]), 'conf': torch.Size([1, 3675, 10]), 'mask': torch.Size([1, 3675, 32]), 'priors': torch.Size([3675, 4])}
            # Output of ModuleList[2]: p = {'loc': torch.Size([1, 972, 4]), 'conf': torch.Size([1, 972, 10]), 'mask': torch.Size([1, 972, 32]), 'priors': torch.Size([972, 4])}
            # Output of ModuleList[3]: p = {'loc': torch.Size([1, 243, 4]), 'conf': torch.Size([1, 243, 10]), 'mask': torch.Size([1, 243, 32]), 'priors': torch.Size([243, 4])}
            # Output of ModuleList[4]: p = {'loc': torch.Size([1, 75, 4]), 'conf': torch.Size([1, 75, 10]), 'mask': torch.Size([1, 75, 32]), 'priors': torch.Size([75, 4])}
            p = pred_layer(pred_x)

            for k, v in p.items():
                pred_outs[k].append(v)

    # Concatenate the values under each key, e.g. pred_outs['loc'] = torch.Size([1, 19248, 4])
    for k, v in pred_outs.items():
        pred_outs[k] = torch.cat(v, -2)

    if proto_out is not None:
        pred_outs['proto'] = proto_out

    pred_outs['conf'] = F.softmax(pred_outs['conf'], -1)

    # self.detect is defined in layers/functions/detection.py
    return self.detect(pred_outs, self)
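Where the 19248 priors in the comments above come from: each FPN level contributes (feature map size)² × (anchors per location) priors, with 3 aspect ratios per location in the base config. A quick check:

sizes = [69, 35, 18, 9, 5]            # P3..P7 feature map sizes
print([s * s * 3 for s in sizes])     # [14283, 3675, 972, 243, 75]
print(sum(s * s * 3 for s in sizes))  # 19248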
The layers/functions/detection.py file
def __call__(self, predictions, net):
    loc_data = predictions['loc']
    conf_data = predictions['conf']
    mask_data = predictions['mask']
    prior_data = predictions['priors']

    proto_data = predictions['proto'] if 'proto' in predictions else None
    inst_data = predictions['inst'] if 'inst' in predictions else None

    out = []

    with timer.env('Detect'):
        batch_size = loc_data.size(0)
        num_priors = prior_data.size(0)

        conf_preds = conf_data.view(batch_size, num_priors, self.num_classes).transpose(2, 1).contiguous()

        for batch_idx in range(batch_size):
            decoded_boxes = decode(loc_data[batch_idx], prior_data)
            result = self.detect(batch_idx, conf_preds, decoded_boxes, mask_data, inst_data)

            if result is not None and proto_data is not None:
                result['proto'] = proto_data[batch_idx]

            out.append({'detection': result, 'net': net})

    return out
def decode(loc, priors, use_yolo_regressors: bool = False):
    if use_yolo_regressors:
        # Decoded boxes in center-size notation
        boxes = torch.cat((
            loc[:, :2] + priors[:, :2],
            priors[:, 2:] * torch.exp(loc[:, 2:])
        ), 1)

        boxes = point_form(boxes)
    else:
        variances = [0.1, 0.2]

        boxes = torch.cat((
            priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
            priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)

        # Convert from center-size to point form: (x1, y1) = center - wh / 2, then (x2, y2) = (x1, y1) + wh
        boxes[:, :2] -= boxes[:, 2:] / 2
        boxes[:, 2:] += boxes[:, :2]

    return boxes
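A worked example for the SSD-style branch above: a zero regression output must reproduce its prior, converted to point form.

import torch

prior = torch.tensor([[0.5, 0.5, 0.2, 0.2]])  # (cx, cy, w, h), normalized
loc = torch.zeros(1, 4)                       # no predicted offset
print(decode(loc, prior))                     # tensor([[0.4000, 0.4000, 0.6000, 0.6000]])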
def detect(self, batch_idx, conf_preds, decoded_boxes, mask_data, inst_data):
    # conf_preds: torch.Size([1, 10, 19248]) -> cur_scores (background class dropped): torch.Size([9, 19248]) -> conf_scores: torch.Size([19248])
    cur_scores = conf_preds[batch_idx, 1:, :]
    conf_scores, _ = torch.max(cur_scores, dim=0)

    # keep: torch.Size([19248])
    keep = (conf_scores > self.conf_thresh)

    # scores: torch.Size([9, 12]) -- the 12 varies with how many entries of keep are True
    # boxes:  torch.Size([12, 4])
    # masks:  torch.Size([12, 32])
    scores = cur_scores[:, keep]
    boxes = decoded_boxes[keep, :]
    masks = mask_data[batch_idx, keep, :]

    if self.use_fast_nms:
        if self.use_cross_class_nms:
            boxes, masks, classes, scores = self.cc_fast_nms(boxes, masks, scores, self.nms_thresh, self.top_k)
        else:
            boxes, masks, classes, scores = self.fast_nms(boxes, masks, scores, self.nms_thresh, self.top_k)

    # box:   torch.Size([36, 4])
    # mask:  torch.Size([36, 32])
    # class: torch.Size([36])
    # score: torch.Size([36])
    return {'box': boxes, 'mask': masks, 'class': classes, 'score': scores}
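fast_nms itself is not shown here. The idea from the YOLACT paper is to allow already-suppressed boxes to suppress other boxes, which turns NMS into a single matrix operation. A simplified single-class sketch, assuming jaccard() from layers/box_utils.py for pairwise IoU (the real code runs all classes at once):

import torch

def fast_nms_sketch(boxes, scores, iou_threshold=0.5, top_k=200):
    # Sort detections by score and keep at most top_k of them.
    scores, idx = scores.sort(descending=True)
    idx = idx[:top_k]
    boxes = boxes[idx]

    # Pairwise IoU; zero the diagonal and lower triangle so each box is only
    # compared against strictly higher-scoring boxes.
    iou = jaccard(boxes, boxes)
    iou.triu_(diagonal=1)
    iou_max, _ = iou.max(dim=0)

    # Keep a box iff no higher-scoring box overlaps it above the threshold.
    return idx[iou_max <= iou_threshold]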
postprocess
def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
                visualize_lincomb=False, crop_masks=True, score_threshold=0):
    dets = det_output[batch_idx]
    net = dets['net']
    dets = dets['detection']

    if score_threshold > 0:
        keep = dets['score'] > score_threshold

        for k in dets:
            if k != 'proto':
                dets[k] = dets[k][keep]

        if dets['score'].size(0) == 0:
            return [torch.Tensor()] * 4

    classes = dets['class']
    boxes = dets['box']
    scores = dets['score']
    masks = dets['mask']

    if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
        proto_data = dets['proto']

        if cfg.mask_proto_debug:
            np.save('scripts/proto.npy', proto_data.cpu().numpy())

        if visualize_lincomb:
            display_lincomb(proto_data, masks)

        # This is the Mask Assembly step from the paper; @ is matrix multiplication
        # proto_data: torch.Size([138, 138, 32])
        # masks.t():  torch.Size([32, 6])
        # masks (after assembly): torch.Size([138, 138, 6])
        # 'mask_proto_mask_activation': activation_func.sigmoid
        masks = proto_data @ masks.t()
        masks = cfg.mask_proto_mask_activation(masks)

        if crop_masks:
            masks = crop(masks, boxes)

        masks = masks.permute(2, 0, 1).contiguous()

        if cfg.use_maskiou:
            with timer.env('maskiou_net'):
                with torch.no_grad():
                    maskiou_p = net.maskiou_net(masks.unsqueeze(1))
                    maskiou_p = torch.gather(maskiou_p, dim=1, index=classes.unsqueeze(1)).squeeze(1)
                    if cfg.rescore_mask:
                        if cfg.rescore_bbox:
                            scores = scores * maskiou_p
                        else:
                            scores = [scores, scores * maskiou_p]

        masks = F.interpolate(masks.unsqueeze(0), (h, w), mode=interpolation_mode, align_corners=False).squeeze(0)

        # Binarize the masks at 0.5
        masks.gt_(0.5)

    boxes[:, 0], boxes[:, 2] = sanitize_coordinates(boxes[:, 0], boxes[:, 2], w, cast=False)
    boxes[:, 1], boxes[:, 3] = sanitize_coordinates(boxes[:, 1], boxes[:, 3], h, cast=False)
    boxes = boxes.long()

    if cfg.mask_type == mask_type.direct and cfg.eval_mask_branch:
        # Upscale masks
        full_masks = torch.zeros(masks.size(0), h, w)

        for jdx in range(masks.size(0)):
            x1, y1, x2, y2 = boxes[jdx, :]

            mask_w = x2 - x1
            mask_h = y2 - y1

            # Just in case
            if mask_w * mask_h <= 0 or mask_w < 0:
                continue

            mask = masks[jdx, :].view(1, 1, cfg.mask_size, cfg.mask_size)
            mask = F.interpolate(mask, (mask_h, mask_w), mode=interpolation_mode, align_corners=False)
            mask = mask.gt(0.5).float()
            full_masks[jdx, y1:y2, x1:x2] = mask

        masks = full_masks

    return classes, scores, boxes, masks
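crop() (from layers/box_utils.py) zeroes everything outside each detection's box, so an assembled mask cannot leak onto other instances. A simplified sketch of the idea, assuming normalized point-form boxes:

import torch

def crop_sketch(masks, boxes):
    # masks: (h, w, n) soft masks; boxes: (n, 4) as (x1, y1, x2, y2) in [0, 1]
    h, w, n = masks.size()
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]

    cols = torch.arange(w, device=masks.device).float().view(1, -1, 1) / w  # x coordinate grid
    rows = torch.arange(h, device=masks.device).float().view(-1, 1, 1) / h  # y coordinate grid

    inside = (cols >= x1) & (cols < x2) & (rows >= y1) & (rows < y2)  # broadcasts to (h, w, n)
    return masks * inside.float()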
Loss
def forward(self, net, predictions, targets, masks, num_crowds):
    loc_data = predictions['loc']
    conf_data = predictions['conf']
    mask_data = predictions['mask']
    priors = predictions['priors']

    if cfg.mask_type == mask_type.lincomb:
        proto_data = predictions['proto']

    score_data = predictions['score'] if cfg.use_mask_scoring else None
    inst_data = predictions['inst'] if cfg.use_instance_coeff else None

    labels = [None] * len(targets)  # Used in sem segm loss

    batch_size = loc_data.size(0)
    num_priors = priors.size(0)
    num_classes = self.num_classes

    # Match priors (default boxes) and ground truth boxes
    # These tensors will be created with the same device as loc_data
    loc_t = loc_data.new(batch_size, num_priors, 4)
    gt_box_t = loc_data.new(batch_size, num_priors, 4)
    conf_t = loc_data.new(batch_size, num_priors).long()
    idx_t = loc_data.new(batch_size, num_priors).long()

    if cfg.use_class_existence_loss:
        class_existence_t = loc_data.new(batch_size, num_classes - 1)

    for idx in range(batch_size):
        truths = targets[idx][:, :-1].data
        labels[idx] = targets[idx][:, -1].data.long()

        if cfg.use_class_existence_loss:
            # Construct a one-hot vector for each object and collapse it into an existence vector with max
            # Also it's fine to include the crowd annotations here
            class_existence_t[idx, :] = torch.eye(num_classes - 1, device=conf_t.get_device())[labels[idx]].max(dim=0)[0]

        # Split the crowd annotations because they come bundled in
        cur_crowds = num_crowds[idx]
        if cur_crowds > 0:
            split = lambda x: (x[-cur_crowds:], x[:-cur_crowds])
            crowd_boxes, truths = split(truths)

            # We don't use the crowd labels or masks
            _, labels[idx] = split(labels[idx])
            _, masks[idx] = split(masks[idx])
        else:
            crowd_boxes = None

        match(self.pos_threshold, self.neg_threshold,
              truths, priors.data, labels[idx], crowd_boxes,
              loc_t, conf_t, idx_t, idx, loc_data[idx])

        gt_box_t[idx, :, :] = truths[idx_t[idx]]

    # wrap targets
    loc_t = Variable(loc_t, requires_grad=False)
    conf_t = Variable(conf_t, requires_grad=False)
    idx_t = Variable(idx_t, requires_grad=False)

    pos = conf_t > 0
    num_pos = pos.sum(dim=1, keepdim=True)

    # Shape: [batch, num_priors, 4]
    pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)

    losses = {}

    # Localization Loss (Smooth L1)
    if cfg.train_boxes:
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)

        losses['B'] = F.smooth_l1_loss(loc_p, loc_t, reduction='sum') * cfg.bbox_alpha

    if cfg.train_masks:
        if cfg.mask_type == mask_type.direct:
            if cfg.use_gt_bboxes:
                pos_masks = []
                for idx in range(batch_size):
                    pos_masks.append(masks[idx][idx_t[idx, pos[idx]]])
                masks_t = torch.cat(pos_masks, 0)
                masks_p = mask_data[pos, :].view(-1, cfg.mask_dim)
                losses['M'] = F.binary_cross_entropy(torch.clamp(masks_p, 0, 1), masks_t, reduction='sum') * cfg.mask_alpha
            else:
                losses['M'] = self.direct_mask_loss(pos_idx, idx_t, loc_data, mask_data, priors, masks)
        elif cfg.mask_type == mask_type.lincomb:
            ret = self.lincomb_mask_loss(pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels)
            if cfg.use_maskiou:
                loss, maskiou_targets = ret
            else:
                loss = ret

            losses.update(loss)

            if cfg.mask_proto_loss is not None:
                if cfg.mask_proto_loss == 'l1':
                    losses['P'] = torch.mean(torch.abs(proto_data)) / self.l1_expected_area * self.l1_alpha
                elif cfg.mask_proto_loss == 'disj':
                    losses['P'] = -torch.mean(torch.max(F.log_softmax(proto_data, dim=-1), dim=-1)[0])

    # Confidence loss
    if cfg.use_focal_loss:
        if cfg.use_sigmoid_focal_loss:
            losses['C'] = self.focal_conf_sigmoid_loss(conf_data, conf_t)
        elif cfg.use_objectness_score:
            losses['C'] = self.focal_conf_objectness_loss(conf_data, conf_t)
        else:
            losses['C'] = self.focal_conf_loss(conf_data, conf_t)
    else:
        if cfg.use_objectness_score:
            losses['C'] = self.conf_objectness_loss(conf_data, conf_t, batch_size, loc_p, loc_t, priors)
        else:
            losses['C'] = self.ohem_conf_loss(conf_data, conf_t, pos, batch_size)

    # Mask IoU Loss
    if cfg.use_maskiou and maskiou_targets is not None:
        losses['I'] = self.mask_iou_loss(net, maskiou_targets)

    # These losses also don't depend on anchors
    if cfg.use_class_existence_loss:
        losses['E'] = self.class_existence_loss(predictions['classes'], class_existence_t)
    if cfg.use_semantic_segmentation_loss:
        losses['S'] = self.semantic_segmentation_loss(predictions['segm'], masks, labels)

    # Divide all losses by the number of positives.
    # Don't do it for loss[P] because that doesn't depend on the anchors.
    total_num_pos = num_pos.data.sum().float()
    for k in losses:
        if k not in ('P', 'E', 'S'):
            losses[k] /= total_num_pos
        else:
            losses[k] /= batch_size

    # Loss Key:
    #  - B: Box Localization Loss
    #  - C: Class Confidence Loss
    #  - M: Mask Loss
    #  - P: Prototype Loss
    #  - D: Coefficient Diversity Loss
    #  - E: Class Existence Loss
    #  - S: Semantic Segmentation Loss
    return losses
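For reference, train.py consumes this dict by simply summing the values, since every term already carries its alpha weight (a sketch of that call site):

losses = criterion(net, predictions, targets, masks, num_crowds)
loss = sum(losses[k] for k in losses)  # total training loss
loss.backward()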
def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear'):
    # proto_data: torch.Size([8, 138, 138, 32])
    mask_h = proto_data.size(1)
    mask_w = proto_data.size(2)

    process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

    if cfg.mask_proto_remove_empty_masks:
        # Make sure to store a copy of this because we edit it to get rid of all-zero masks
        pos = pos.clone()

    loss_m = 0
    loss_d = 0  # Coefficient diversity loss

    maskiou_t_list = []
    maskiou_net_input_list = []
    label_t_list = []

    # mask_data: torch.Size([8, 19248, 32])
    for idx in range(mask_data.size(0)):
        with torch.no_grad():
            # masks[0]: torch.Size([33, 550, 550])
            # downsampled_masks: torch.Size([33, 138, 138])
            downsampled_masks = F.interpolate(masks[idx].unsqueeze(0), (mask_h, mask_w),
                                              mode=interpolation_mode, align_corners=False).squeeze(0)
            downsampled_masks = downsampled_masks.permute(1, 2, 0).contiguous()

            if cfg.mask_proto_binarize_downsampled_gt:
                # torch.gt(a, b) is an elementwise strict greater-than: 1 where a > b, 0 otherwise
                downsampled_masks = downsampled_masks.gt(0.5).float()

        # cur_pos: torch.Size([19248])
        cur_pos = pos[idx]
        # pos_idx_t: torch.Size([91])
        pos_idx_t = idx_t[idx, cur_pos]

        if process_gt_bboxes:
            # Note: this is in point-form
            if cfg.mask_proto_crop_with_pred_box:
                pos_gt_box_t = decode(loc_data[idx, :, :], priors.data, cfg.use_yolo_regressors)[cur_pos]
            else:
                # gt_box_t: torch.Size([8, 19248, 4]); pos_gt_box_t: torch.Size([91, 4])
                pos_gt_box_t = gt_box_t[idx, cur_pos]

        # proto_masks: torch.Size([138, 138, 32])
        proto_masks = proto_data[idx]
        # proto_coef: torch.Size([91, 32])
        proto_coef = mask_data[idx, cur_pos, :]

        # If we have over the allowed number of masks, select a random sample
        old_num_pos = proto_coef.size(0)
        if old_num_pos > cfg.masks_to_train:
            perm = torch.randperm(proto_coef.size(0))
            select = perm[:cfg.masks_to_train]

            proto_coef = proto_coef[select, :]
            pos_idx_t = pos_idx_t[select]

            if process_gt_bboxes:
                pos_gt_box_t = pos_gt_box_t[select, :]
            if cfg.use_mask_scoring:
                mask_scores = mask_scores[select, :]

        num_pos = proto_coef.size(0)
        # mask_t: torch.Size([138, 138, 91])
        mask_t = downsampled_masks[:, :, pos_idx_t]
        # label_t: torch.Size([91])
        label_t = labels[idx][pos_idx_t]

        # pred_masks: torch.Size([138, 138, 91])
        # cfg.mask_proto_mask_activation: activation_func.sigmoid
        pred_masks = proto_masks @ proto_coef.t()
        pred_masks = cfg.mask_proto_mask_activation(pred_masks)

        if cfg.mask_proto_crop:
            pred_masks = crop(pred_masks, pos_gt_box_t)

        if cfg.mask_proto_mask_activation == activation_func.sigmoid:
            pre_loss = F.binary_cross_entropy(torch.clamp(pred_masks, 0, 1), mask_t, reduction='none')
        else:
            pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='none')

        if cfg.mask_proto_normalize_emulate_roi_pooling:
            # weight = 138 * 138 = 19044
            weight = mask_h * mask_w if cfg.mask_proto_crop else 1
            pos_gt_csize = center_size(pos_gt_box_t)
            gt_box_width = pos_gt_csize[:, 2] * mask_w
            gt_box_height = pos_gt_csize[:, 3] * mask_h
            pre_loss = pre_loss.sum(dim=(0, 1)) / gt_box_width / gt_box_height * weight

        # If the number of masks were limited scale the loss accordingly
        if old_num_pos > num_pos:
            pre_loss *= old_num_pos / num_pos

        loss_m += torch.sum(pre_loss)

        if cfg.use_maskiou:
            if cfg.discard_mask_area > 0:
                gt_mask_area = torch.sum(mask_t, dim=(0, 1))
                select = gt_mask_area > cfg.discard_mask_area

                if torch.sum(select) < 1:
                    continue

                pos_gt_box_t = pos_gt_box_t[select, :]
                pred_masks = pred_masks[:, :, select]
                mask_t = mask_t[:, :, select]
                label_t = label_t[select]

            maskiou_net_input = pred_masks.permute(2, 0, 1).contiguous().unsqueeze(1)
            pred_masks = pred_masks.gt(0.5).float()
            maskiou_t = self._mask_iou(pred_masks, mask_t)

            maskiou_net_input_list.append(maskiou_net_input)
            maskiou_t_list.append(maskiou_t)
            label_t_list.append(label_t)

    losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

    if cfg.mask_proto_coeff_diversity_loss:
        losses['D'] = loss_d

    if cfg.use_maskiou:
        # discard_mask_area discarded every mask in the batch, so nothing to do here
        if len(maskiou_t_list) == 0:
            return losses, None

        maskiou_t = torch.cat(maskiou_t_list)
        label_t = torch.cat(label_t_list)
        maskiou_net_input = torch.cat(maskiou_net_input_list)

        num_samples = maskiou_t.size(0)
        if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
            perm = torch.randperm(num_samples)
            select = perm[:cfg.masks_to_train]
            maskiou_t = maskiou_t[select]
            label_t = label_t[select]
            maskiou_net_input = maskiou_net_input[select]

        return losses, [maskiou_net_input, maskiou_t, label_t]

    return losses
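_mask_iou, used above to build the Mask IoU targets, reduces to per-mask intersection over union across the two spatial dimensions. A sketch consistent with the shapes above ((h, w, n) binarized masks; torch assumed imported):

def _mask_iou(self, mask1, mask2):
    intersection = torch.sum(mask1 * mask2, dim=(0, 1))  # (n,)
    area1 = torch.sum(mask1, dim=(0, 1))
    area2 = torch.sum(mask2, dim=(0, 1))
    union = area1 + area2 - intersection
    return intersection / union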
Model Training
python train.py --config=yolact_base_config --resume=weights/yolact_base_10_32100.pth --start_iter=-1
Model Evaluation
python eval.py --trained_model=weights/yolact_base_54_800000.pth
Detecting Images
python3 eval.py --trained_model=weights/yolact_base_cityscapes.pth --score_threshold=0.15 --top_k=15 --image=leftImg8bit.png:test.png