BlendMask Code Walkthrough
The code below comes from the official AdelaiDet repository, with comments added to make it easier to read. The BlendMask implementation mainly consists of the following three files (a rough sketch of how they cooperate follows the list):
- blendmask.py
- blender.py
- basis_module.py
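At a high level, the three modules cooperate as sketched below. This is pseudocode only, not runnable on its own; the names match the code that follows, and the shapes come from the default config with K = 4 bases and a 14x14 attention map.
# Pseudocode sketch of one BlendMask forward pass
features = backbone(images.tensor)                    # FPN maps p3..p7, 256 channels each
bases, basis_losses = basis_module(features, sem_gt)  # K = 4 basis maps at 1/4 input resolution
# FCOS additionally applies top_layer (256 -> K*14*14 = 784) to predict per-location attentions
proposals, proposal_losses = proposal_generator(images, features, gt_instances, top_layer)
# the blender RoIAligns the bases inside each box and blends them with the attentions
detector_results, detector_losses = blender(bases, proposals, gt_instances)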
AdelaiDet/adet/modeling/blendmask/blendmask.py
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
from torch import nn
from detectron2.structures import ImageList
from detectron2.modeling.postprocessing import detector_postprocess, sem_seg_postprocess
from detectron2.modeling.proposal_generator import build_proposal_generator
from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.meta_arch.panoptic_fpn import combine_semantic_and_instance_outputs
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.semantic_seg import build_sem_seg_head
from .blender import build_blender
from .basis_module import build_basis_module
__all__ = ["BlendMask"]
# for debugging
import pdb
@META_ARCH_REGISTRY.register()
class BlendMask(nn.Module):
"""
    Main class for BlendMask architectures (see https://arxiv.org/abs/1901.02446).
"""
def __init__(self, cfg):
super().__init__()
self.device = torch.device(cfg.MODEL.DEVICE)
self.instance_loss_weight = cfg.MODEL.BLENDMASK.INSTANCE_LOSS_WEIGHT
        self.backbone = build_backbone(cfg)  # backbone network
self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape())
self.blender = build_blender(cfg)
self.basis_module = build_basis_module(cfg, self.backbone.output_shape())
# options when combining instance & semantic outputs
self.combine_on = cfg.MODEL.PANOPTIC_FPN.COMBINE.ENABLED
if self.combine_on:
self.panoptic_module = build_sem_seg_head(cfg, self.backbone.output_shape())
self.combine_overlap_threshold = cfg.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH
self.combine_stuff_area_limit = cfg.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT
self.combine_instances_confidence_threshold = (
cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH)
# build top module
        # All of the values below can be found in log.txt in your training output directory
        in_channels = cfg.MODEL.FPN.OUT_CHANNELS  # 256 for FPN
        num_bases = cfg.MODEL.BASIS_MODULE.NUM_BASES  # 4
        attn_size = cfg.MODEL.BLENDMASK.ATTN_SIZE  # 14
        attn_len = num_bases * attn_size * attn_size  # K*M*M = 4*14*14 = 784
self.top_layer = nn.Conv2d(
in_channels, attn_len,
kernel_size=3, stride=1, padding=1)
torch.nn.init.normal_(self.top_layer.weight, std=0.01)
torch.nn.init.constant_(self.top_layer.bias, 0)
pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1)
pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)
self.normalizer = lambda x: (x - pixel_mean) / pixel_std
self.to(self.device)
#pdb.set_trace()
def forward(self, batched_inputs):
"""
Args:
batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
Each item in the list contains the inputs for one image.
For now, each item in the list is a dict that contains:
image: Tensor, image in (C, H, W) format.
instances: Instances
sem_seg: semantic segmentation ground truth.
Other information that's included in the original dicts, such as:
"height", "width" (int): the output resolution of the model, used in inference.
See :meth:`postprocess` for details.
Returns:
list[dict]: each dict is the results for one image. The dict
contains the following keys:
"instances": see :meth:`GeneralizedRCNN.forward` for its format.
"sem_seg": see :meth:`SemanticSegmentor.forward` for its format.
"panoptic_seg": available when `PANOPTIC_FPN.COMBINE.ENABLED`.
See the return value of
:func:`combine_semantic_and_instance_outputs` for its format.
"""
images = [x["image"].to(self.device) for x in batched_inputs]
        images = [self.normalizer(x) for x in images]  # normalize
images = ImageList.from_tensors(images, self.backbone.size_divisibility)
features = self.backbone(images.tensor)
"""
        (Pdb) features.keys()
dict_keys(['p3', 'p4', 'p5', 'p6', 'p7'])
"""
#pdb.set_trace()
if self.combine_on:
if "sem_seg" in batched_inputs[0]:
gt_sem = [x["sem_seg"].to(self.device) for x in batched_inputs]
gt_sem = ImageList.from_tensors(
gt_sem, self.backbone.size_divisibility, self.panoptic_module.ignore_value
).tensor
else:
gt_sem = None
sem_seg_results, sem_seg_losses = self.panoptic_module(features, gt_sem)
if "basis_sem" in batched_inputs[0]:
basis_sem = [x["basis_sem"].to(self.device) for x in batched_inputs]
basis_sem = ImageList.from_tensors(
basis_sem, self.backbone.size_divisibility, 0).tensor
else:
basis_sem = None
basis_out, basis_losses = self.basis_module(features, basis_sem)
"""
(Pdb) basis_losses
{'loss_basis_sem': tensor(1.3870, device='cuda:0', grad_fn=<MulBackward0>)}
"""
#pdb.set_trace()
if "instances" in batched_inputs[0]:
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
else:
gt_instances = None
        # See line 444 of fcos_outputs.py: self.top_layer does not feed FCOS's own branches or loss computation; it only adds one extra channel transform, 256 --> 784
"""
(Pdb) proposal_losses
{'loss_fcos_cls': tensor(1.2077, device='cuda:0', grad_fn=<MulBackward0>),
'loss_fcos_loc': tensor(0.9512, device='cuda:0', grad_fn=<DivBackward0>),
'loss_fcos_ctr': tensor(0.7081, device='cuda:0', grad_fn=<DivBackward0>)}
"""
proposals, proposal_losses = self.proposal_generator(
images, features, gt_instances, self.top_layer)
#pdb.set_trace()
"""
(Pdb) detector_losses
{'loss_mask': tensor(0.6993, device='cuda:0', grad_fn=<DivBackward0>)}
"""
        detector_results, detector_losses = self.blender(  # invokes Blender.__call__
basis_out["bases"], proposals, gt_instances)
#pdb.set_trace()
if self.training:
losses = {}
losses.update(basis_losses)
losses.update({k: v * self.instance_loss_weight for k, v in detector_losses.items()})
losses.update(proposal_losses)
if self.combine_on:
losses.update(sem_seg_losses)
return losses
processed_results = []
#pdb.set_trace()
for i, (detector_result, input_per_image, image_size) in enumerate(zip(
detector_results, batched_inputs, images.image_sizes)):
height = input_per_image.get("height", image_size[0])
width = input_per_image.get("width", image_size[1])
detector_r = detector_postprocess(detector_result, height, width)
processed_result = {"instances": detector_r}
if self.combine_on:
sem_seg_r = sem_seg_postprocess(
sem_seg_results[i], image_size, height, width)
processed_result["sem_seg"] = sem_seg_r
if "seg_thing_out" in basis_out:
seg_thing_r = sem_seg_postprocess(
basis_out["seg_thing_out"], image_size, height, width)
processed_result["sem_thing_seg"] = seg_thing_r
if self.basis_module.visualize:
processed_result["bases"] = basis_out["bases"]
processed_results.append(processed_result)
if self.combine_on:
panoptic_r = combine_semantic_and_instance_outputs(
detector_r,
sem_seg_r.argmax(dim=0),
self.combine_overlap_threshold,
self.combine_stuff_area_limit,
self.combine_instances_confidence_threshold)
processed_results[-1]["panoptic_seg"] = panoptic_r
#pdb.set_trace()
#pdb.set_trace()
return processed_results
"""
(Pdb) batched_inputs[0]
{'file_name': '/mnt/big_disk/big_disk_1/zhuzhibo/download/AdelaiDet/datasets/handwritten_chinese_stroke_2021/train2021/HandwrittenChineseStroke_train_0000028734.jpg', 'height': 77, 'width': 47, 'image_id': 6569, 'image': tensor(
[[[252, 253, 254, ..., 255, 255, 255],
[252, 253, 254, ..., 255, 255, 255],
[252, 253, 255, ..., 255, 255, 255],
...,
[255, 235, 199, ..., 255, 255, 255],
[255, 247, 234, ..., 255, 255, 255],
[255, 254, 253, ..., 255, 255, 255]],
[[252, 253, 254, ..., 255, 255, 255],
[252, 253, 254, ..., 255, 255, 255],
[252, 253, 255, ..., 255, 255, 255],
...,
[255, 235, 199, ..., 255, 255, 255],
[255, 247, 234, ..., 255, 255, 255],
[255, 254, 253, ..., 255, 255, 255]],
[[252, 253, 254, ..., 255, 255, 255],
[252, 253, 254, ..., 255, 255, 255],
[252, 253, 255, ..., 255, 255, 255],
...,
[255, 235, 199, ..., 255, 255, 255],
[255, 247, 234, ..., 255, 255, 255],
[255, 254, 253, ..., 255, 255, 255]]], dtype=torch.uint8), 'instances': Instances(num_instances=9, image_height=144, image_width=88, fields=[gt_boxes: Boxes(tensor(
[[ 41.0402, 0.0000, 88.0000, 52.9219],
[ 48.8080, 26.3713, 75.8560, 82.4046],
[ 36.6902, 28.8540, 57.2133, 70.4563],
[ 0.0000, 5.1126, 44.1634, 66.7322],
[ 15.7172, 20.0092, 34.6867, 64.0943],
[ 56.4204, 80.8368, 76.4774, 127.8701],
[ 8.2601, 61.4402, 59.2329, 126.1632],
[ 29.2331, 80.9919, 52.5526, 129.2667],
[ 0.0000, 105.6644, 39.9688, 144.0000]])), gt_classes: tensor([ 3, 23, 17, 14, 3, 18, 8, 3, 17]), gt_masks: PolygonMasks(num_instances=9)]), 'basis_sem': tensor(
[[ 0, 0, 0, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 0],
...,
[ 0, 0, 18, ..., 0, 0, 0],
[ 0, 0, 18, ..., 0, 0, 0],
[ 0, 0, 18, ..., 0, 0, 0]])}
"""
AdelaiDet/adet/modeling/blendmask/blender.py
import torch
from torch.nn import functional as F
from detectron2.layers import cat
from detectron2.modeling.poolers import ROIPooler
def build_blender(cfg):
return Blender(cfg)
# for debugging
import pdb
class Blender(object):
def __init__(self, cfg):
# fmt: off
        # All of the values below can be found in log.txt in the training output directory
        self.pooler_resolution = cfg.MODEL.BLENDMASK.BOTTOM_RESOLUTION # 56
        sampling_ratio = cfg.MODEL.BLENDMASK.POOLER_SAMPLING_RATIO # 1
        pooler_type = cfg.MODEL.BLENDMASK.POOLER_TYPE # 'ROIAlignV2'
        pooler_scales = cfg.MODEL.BLENDMASK.POOLER_SCALES # (0.25,)
        self.attn_size = cfg.MODEL.BLENDMASK.ATTN_SIZE # 14
        self.top_interp = cfg.MODEL.BLENDMASK.TOP_INTERP # 'bilinear'
        num_bases = cfg.MODEL.BASIS_MODULE.NUM_BASES # 4
# fmt: on
self.attn_len = num_bases * self.attn_size * self.attn_size # 14*14*4 = 784
self.pooler = ROIPooler(
output_size=self.pooler_resolution,
scales=pooler_scales,
sampling_ratio=sampling_ratio,
pooler_type=pooler_type,
canonical_level=2)
"""
(Pdb) self.pooler
ROIPooler(
(level_poolers): ModuleList(
(0): ROIAlign(output_size=(56, 56), spatial_scale=0.25, sampling_ratio=1, aligned=True)
)
)
"""
#pdb.set_trace()
def __call__(self, bases, proposals, gt_instances):
"""
        gt_instances holds the ground-truth instances for the whole batch:
        len(gt_instances) = 64 (because my batch size is 64), and
        gt_instances[0] is the GT for the first image in the batch, e.g.:
(Pdb) gt_instances[0]
Instances(num_instances=5, image_height=144, image_width=112, fields=[gt_boxes: Boxes(tensor([[ 0.0000, 0.0000, 38.6394, 55.8094],
[ 14.5582, 6.7498, 86.0460, 82.6530],
[ 19.5723, 46.5602, 112.0000, 69.2312],
[ 34.3869, 64.0956, 108.8377, 144.0000],
[ 23.9028, 59.5271, 84.4506, 135.4303]], device='cuda:0')), gt_classes: tensor([ 3, 15, 1, 14, 3], device='cuda:0'), gt_masks: PolygonMasks(num_instances=5)])
"""
if gt_instances is not None:
# training
# reshape attns
            dense_info = proposals["instances"] # 3473 instances in my run
            attns = dense_info.top_feats # attns.shape = [instances, 784]
            pos_inds = dense_info.pos_inds # pos_inds.shape = [instances]: the positive-sample locations, indexed over the flattened points of all FPN levels
if pos_inds.numel() == 0:
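                # no positives in this batch: return a zero "loss" that still
                # touches attns and bases so the autograd graph stays connected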
return None, {"loss_mask": sum([x.sum() * 0 for x in attns]) + bases[0].sum() * 0}
            gt_inds = dense_info.gt_inds # gt_inds.shape = [instances]: for each positive sample, the index of its matched GT instance
            # 1. Corresponds to r_d = RoIPool(B, p_d) in the paper
            rois = self.pooler(bases, [x.gt_boxes for x in gt_instances])
            rois = rois[gt_inds] # rois.shape = [instances, num_bases, 56, 56]; rows are gathered (duplicated) according to gt_inds
pred_mask_logits = self.merge_bases(rois, attns) # pred_mask_logits.shape = [instances, 56*56]
# gen targets
gt_masks = []
            # iterate over the GT instances of each image
for instances_per_image in gt_instances:
if len(instances_per_image.gt_boxes.tensor) == 0:
continue
                # crop and resize to 56x56; gt_mask_per_image.shape = [7 (number of instances in this image), 56, 56]
gt_mask_per_image = instances_per_image.gt_masks.crop_and_resize(
instances_per_image.gt_boxes.tensor, self.pooler_resolution
).to(device=pred_mask_logits.device)
gt_masks.append(gt_mask_per_image)
            gt_masks = cat(gt_masks, dim=0) # gt_masks.shape = [484, 56, 56] (484 = total number of GT instances in the batch)
#pdb.set_trace()
            gt_masks = gt_masks[gt_inds] # gt_masks.shape = [3473, 56, 56]; 3473 = total number of positive-sample instances
#pdb.set_trace()
N = gt_masks.size(0)
gt_masks = gt_masks.view(N, -1) # gt_masks.shape = [3473, 56*56]
gt_ctr = dense_info.gt_ctrs # gt_ctr.shape = [3473]
loss_denorm = proposals["loss_denorm"] # loss_denorm = ctrness_targets.sum()
            # mask BCE loss, shape [3473, 56*56]
            # F.binary_cross_entropy_with_logits is the functional form of torch.nn.BCEWithLogitsLoss; reduction="none" keeps the per-element losses, while reduction="sum" would sum them all
mask_losses = F.binary_cross_entropy_with_logits(
pred_mask_logits, gt_masks.to(dtype=torch.float32), reduction="none")
mask_loss = ((mask_losses.mean(dim=-1) * gt_ctr).sum()
/ loss_denorm)
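            # the mean per-pixel BCE of each instance is weighted by its centerness
            # target and normalized by loss_denorm = ctrness_targets.sum(), the same
            # normalization FCOS applies to its box regression loss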
#pdb.set_trace()
return None, {"loss_mask": mask_loss}
else:
# no proposals
total_instances = sum([len(x) for x in proposals])
if total_instances == 0:
# add empty pred_masks results
for box in proposals:
box.pred_masks = box.pred_classes.view(
-1, 1, self.pooler_resolution, self.pooler_resolution)
return proposals, {}
rois = self.pooler(bases, [x.pred_boxes for x in proposals])
attns = cat([x.top_feat for x in proposals], dim=0)
pred_mask_logits = self.merge_bases(rois, attns).sigmoid()
pred_mask_logits = pred_mask_logits.view(
-1, 1, self.pooler_resolution, self.pooler_resolution)
start_ind = 0
for box in proposals:
end_ind = start_ind + len(box)
box.pred_masks = pred_mask_logits[start_ind:end_ind]
start_ind = end_ind
return proposals, {}
    # the blending step
def merge_bases(self, rois, coeffs, location_to_inds=None):
# merge predictions
        # incoming coeffs.shape = [instances=3473, attn_len=784]
N = coeffs.size(0)
#pdb.set_trace()
if location_to_inds is not None:
rois = rois[location_to_inds]
N, B, H, W = rois.size()
        coeffs = coeffs.view(N, -1, self.attn_size, self.attn_size) # [instances, -1, M, M] --> [instances, num_bases, 14, 14]
        # 2. Corresponds to a'_d = interpolate_{M x M --> R x R}(a_d) and s_d = softmax(a'_d):
        # the softmax is taken over the basis channels (here 4) at every spatial position # [instances, 4, 14, 14] --> [instances, 4, 56, 56]
coeffs = F.interpolate(coeffs, (H, W),
mode=self.top_interp).softmax(dim=1)
        # 3. Corresponds to m_d = sum_k(s^k_d * r^k_d)
masks_preds = (rois * coeffs).sum(dim=1) # [instances, 56, 56]
#pdb.set_trace()
return masks_preds.view(N, -1) # [instances, 56*56]
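The shape bookkeeping in merge_bases is exactly the paper's blending equation m_d = sum_k(s^k_d * r^k_d). A standalone sketch with dummy tensors (dimensions taken from the comments above) makes the flow easy to verify:
# Standalone shape check for the blending step (dummy data, no model required)
import torch
import torch.nn.functional as F

num_inst, K, M, R = 3, 4, 14, 56           # instances, bases, attn size, RoI size
rois = torch.randn(num_inst, K, R, R)      # r_d: per-instance RoIAlign crops of the bases
coeffs = torch.randn(num_inst, K * M * M)  # a_d: flat 784-dim attention vectors

coeffs = coeffs.view(num_inst, K, M, M)                     # [3, 4, 14, 14]
coeffs = F.interpolate(coeffs, (R, R), mode="bilinear",
                       align_corners=False).softmax(dim=1)  # s_d: [3, 4, 56, 56]
masks = (rois * coeffs).sum(dim=1)                          # m_d: [3, 56, 56]
print(masks.shape)  # torch.Size([3, 56, 56])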
AdelaiDet/adet/modeling/blendmask/basis_module.py
from typing import Dict
from torch import nn
from torch.nn import functional as F
from detectron2.utils.registry import Registry
from detectron2.layers import ShapeSpec
from adet.layers import conv_with_kaiming_uniform
BASIS_MODULE_REGISTRY = Registry("BASIS_MODULE")
BASIS_MODULE_REGISTRY.__doc__ = """
Registry for basis module, which produces global bases from feature maps.
The registered object will be called with `obj(cfg, input_shape)`.
The call should return a `nn.Module` object.
"""
# for debugging
import pdb
def build_basis_module(cfg, input_shape):
    name = cfg.MODEL.BASIS_MODULE.NAME # "ProtoNet"; see adet/config/defaults.py for the default
return BASIS_MODULE_REGISTRY.get(name)(cfg, input_shape)
@BASIS_MODULE_REGISTRY.register()
class ProtoNet(nn.Module):
def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
"""
(Pdb) input_shape
{
'p3': ShapeSpec(channels=256, height=None, width=None, stride=8),
'p4': ShapeSpec(channels=256, height=None, width=None, stride=16),
'p5': ShapeSpec(channels=256, height=None, width=None, stride=32),
'p6': ShapeSpec(channels=256, height=None, width=None, stride=64),
'p7': ShapeSpec(channels=256, height=None, width=None, stride=128)
}
"""
"""
TODO: support deconv and variable channel width
"""
# official protonet has a relu after each conv
super().__init__()
# fmt: off
mask_dim = cfg.MODEL.BASIS_MODULE.NUM_BASES # 4
planes = cfg.MODEL.BASIS_MODULE.CONVS_DIM # 128
self.in_features = cfg.MODEL.BASIS_MODULE.IN_FEATURES # ['p3', 'p4', 'p5']
self.loss_on = cfg.MODEL.BASIS_MODULE.LOSS_ON # True
norm = cfg.MODEL.BASIS_MODULE.NORM # BN
num_convs = cfg.MODEL.BASIS_MODULE.NUM_CONVS # 3
self.visualize = cfg.MODEL.BLENDMASK.VISUALIZE # False
# fmt: on
        # {'p3': 256, ..., 'p7': 256}
feature_channels = {k: v.channels for k, v in input_shape.items()}
        conv_block = conv_with_kaiming_uniform(norm, True) # conv + BN + ReLU
self.refine = nn.ModuleList()
for in_feature in self.in_features:
self.refine.append(conv_block(
feature_channels[in_feature], planes, 3, 1))
tower = []
for i in range(num_convs):
tower.append(
conv_block(planes, planes, 3, 1))
tower.append(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False))
tower.append(
conv_block(planes, planes, 3, 1))
tower.append(
nn.Conv2d(planes, mask_dim, 1))
self.add_module('tower', nn.Sequential(*tower))
if self.loss_on:
# fmt: off
self.common_stride = cfg.MODEL.BASIS_MODULE.COMMON_STRIDE # 8
num_classes = cfg.MODEL.BASIS_MODULE.NUM_CLASSES + 1 # 26
self.sem_loss_weight = cfg.MODEL.BASIS_MODULE.LOSS_WEIGHT # 0.3
# fmt: on
inplanes = feature_channels[self.in_features[0]] # 256
self.seg_head = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=3,
stride=1, padding=1, bias=False),
nn.BatchNorm2d(planes),
nn.ReLU(),
nn.Conv2d(planes, planes, kernel_size=3,
stride=1, padding=1, bias=False),
nn.BatchNorm2d(planes),
nn.ReLU(),
nn.Conv2d(planes, num_classes, kernel_size=1,
stride=1))
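            # seg_head predicts per-pixel "thing" semantics, used for the auxiliary
            # loss_basis_sem supervision (and for seg_thing_out when visualizing)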
#pdb.set_trace()
def forward(self, features, targets=None):
"""
:param features: len(features)=5
        :param targets: semantic GT, shape [N, h, w]
"""
for i, f in enumerate(self.in_features):
if i == 0:
x = self.refine[i](features[f])
else:
x_p = self.refine[i](features[f])
x_p = F.interpolate(x_p, x.size()[2:], mode="bilinear", align_corners=False)
# x_p = aligned_bilinear(x_p, x.size(3) // x_p.size(3))
x = x + x_p # [N, 128, h, w]
        # pdb.set_trace()
outputs = {"bases": [self.tower(x)]}
losses = {}
        # auxiliary thing semantic loss
if self.training and self.loss_on:
sem_out = self.seg_head(features[self.in_features[0]])
# resize target to reduce memory
gt_sem = targets.unsqueeze(1).float()
gt_sem = F.interpolate(
gt_sem, scale_factor=1 / self.common_stride)
seg_loss = F.cross_entropy(
sem_out, gt_sem.squeeze(1).long())
losses['loss_basis_sem'] = seg_loss * self.sem_loss_weight
elif self.visualize and hasattr(self, "seg_head"):
outputs["seg_thing_out"] = self.seg_head(features[self.in_features[0]])
        # pdb.set_trace()
return outputs, losses
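One detail worth cross-checking with blender.py: the tower contains exactly one 2x Upsample, so bases computed from p3 (stride 8) come out at stride 4, which is what the blender's pooler_scales = (0.25,) assumes. A quick arithmetic sketch:
# Sanity check (arithmetic only): basis stride vs. the blender's RoIAlign scale
p3_stride = 8                                # the tower's input comes from p3
upsample_factor = 2                          # one nn.Upsample(scale_factor=2) in the tower
basis_stride = p3_stride // upsample_factor  # = 4: bases are at 1/4 input resolution
assert 1.0 / basis_stride == 0.25            # matches cfg.MODEL.BLENDMASK.POOLER_SCALES[0]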
References
https://www.freesion.com/article/58821475331/
https://github.com/aim-uofa/AdelaiDet
You are welcome to visit my personal space: ZzBoAYU HOME