class mask_rcnn_fcn_head_v1upXconvs_gn_adp_ff(nn.Module):
    """v1upXconvs design: X * (conv 3x3), convT 2x2, with GroupNorm.

    Mask head with adaptive feature pooling and a fully-connected fusion branch.
    Data flow (see ``forward``):
      1. ``roi_xform`` is called with ``panet=True`` and returns an indexable
         sequence of RoI feature maps (presumably one per FPN level -- confirm
         against the roi_xform implementation).
      2. Entry ``i`` goes through its own ``mask_conv1[i]`` (conv 3x3 + GN + ReLU);
         the results are fused with an element-wise max.
      3. ``conv_fcn`` applies two more conv 3x3 + GN + ReLU layers.
      4. FCN branch: ``mask_conv4`` then a 2x2 / stride-2 transposed conv + ReLU,
         which doubles height and width.
      5. FF branch: ``mask_conv4_fc`` then ``mask_conv5_fc`` (halves channels),
         flattened and projected by ``mask_fc`` to ``cfg.MRCNN.RESOLUTION ** 2``.

    Args:
        dim_in: channel count of the RoI features produced by ``roi_xform_func``.
        roi_xform_func: RoI feature extractor invoked at the start of ``forward``.
        spatial_scale: forwarded unchanged to ``roi_xform_func``.
        num_convs: stored as ``self.num_convs``. NOTE: the number of conv layers
            built here is fixed (mask_conv1 + 2 in conv_fcn + mask_conv4) and
            does not depend on this value.
    """

    def __init__(self, dim_in, roi_xform_func, spatial_scale, num_convs):
        super().__init__()
        self.dim_in = dim_in
        self.roi_xform = roi_xform_func
        self.spatial_scale = spatial_scale
        self.num_convs = num_convs

        dilation = cfg.MRCNN.DILATION
        dim_inner = cfg.MRCNN.DIM_REDUCED
        dim_half = dim_inner // 2
        self.dim_out = dim_inner

        def conv_gn_relu(in_ch, out_ch):
            """Return [conv 3x3 (dilated, no bias), GroupNorm, ReLU] for in_ch -> out_ch.

            padding == dilation with stride 1 keeps the spatial size unchanged.
            The GroupNorm group count is derived from the layer's own channel count.
            """
            return [
                nn.Conv2d(in_ch, out_ch, 3, 1, padding=1 * dilation, dilation=dilation, bias=False),
                nn.GroupNorm(net_utils.get_group_gn(out_ch), out_ch, eps=cfg.GROUP_NORM.EPSILON),
                nn.ReLU(inplace=True),
            ]

        # Shared trunk applied after level fusion. Its input is the output of
        # mask_conv1 (dim_inner channels), so every conv here is dim_inner -> dim_inner.
        module_list = []
        for _ in range(2):
            module_list.extend(conv_gn_relu(dim_inner, dim_inner))
        self.conv_fcn = nn.Sequential(*module_list)

        # One conv per FPN level for adaptive feature pooling. These consume the
        # raw RoI features, so their input width is the constructor's dim_in.
        num_levels = cfg.FPN.ROI_MAX_LEVEL - cfg.FPN.ROI_MIN_LEVEL + 1
        self.mask_conv1 = nn.ModuleList(
            [nn.Sequential(*conv_gn_relu(self.dim_in, dim_inner)) for _ in range(num_levels)])

        # FCN branch conv (followed by self.upconv in forward).
        self.mask_conv4 = nn.Sequential(*conv_gn_relu(dim_inner, dim_inner))
        # FF branch convs; the second one halves the channel count.
        self.mask_conv4_fc = nn.Sequential(*conv_gn_relu(dim_inner, dim_inner))
        self.mask_conv5_fc = nn.Sequential(*conv_gn_relu(dim_inner, dim_half))
        # The convs preserve spatial size, so the flattened FF features are
        # dim_half * ROI_XFORM_RESOLUTION ** 2 wide.
        self.mask_fc = nn.Sequential(
            nn.Linear(dim_half * (cfg.MRCNN.ROI_XFORM_RESOLUTION) ** 2, cfg.MRCNN.RESOLUTION ** 2, bias=True),
            nn.ReLU(inplace=True))
        # upsample layer: kernel 2, stride 2, padding 0 -> doubles height and width
        self.upconv = nn.ConvTranspose2d(dim_inner, dim_inner, 2, 2, 0)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one submodule (called via ``self.apply``).

        Conv2d / ConvTranspose2d weights use ``cfg.MRCNN.CONV_INIT`` and their biases
        (if any) are zeroed. Linear weights use N(0, 0.01) with zero bias. Other
        module types (e.g. GroupNorm) are left untouched.

        Raises:
            ValueError: if ``cfg.MRCNN.CONV_INIT`` is neither 'GaussianFill' nor 'MSRAFill'.
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            if cfg.MRCNN.CONV_INIT == 'GaussianFill':
                init.normal_(m.weight, std=0.001)
            elif cfg.MRCNN.CONV_INIT == 'MSRAFill':
                mynn.init.MSRAFill(m.weight)
            else:
                raise ValueError('Unsupported cfg.MRCNN.CONV_INIT: %r' % (cfg.MRCNN.CONV_INIT,))
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=0.01)
            init.constant_(m.bias, 0)

    def detectron_weight_mapping(self):
        """Map parameter names of this head to Detectron blob names.

        Only ``conv_fcn`` and ``upconv`` are mapped.
        NOTE(review): mask_conv1 / mask_conv4 / mask_conv*_fc / mask_fc have no
        entries here.

        Returns:
            (dict of param name -> Detectron blob name, empty orphan list)
        """
        mapping_to_detectron = {}
        # conv_fcn is a flat Sequential of (Conv2d, GroupNorm, ReLU) triples: the
        # conv of triple i is at index 3*i and its GroupNorm at 3*i+1. Iterate over
        # the triples that actually exist, not self.num_convs, which can exceed
        # the depth of conv_fcn and would yield keys for missing layers.
        for i in range(len(self.conv_fcn) // 3):
            mapping_to_detectron.update({
                'conv_fcn.%d.weight' % (3 * i): '_mask_fcn%d_w' % (i + 1),
                'conv_fcn.%d.weight' % (3 * i + 1): '_mask_fcn%d_gn_s' % (i + 1),
                'conv_fcn.%d.bias' % (3 * i + 1): '_mask_fcn%d_gn_b' % (i + 1),
            })
        mapping_to_detectron.update({
            'upconv.weight': 'conv5_mask_w',
            'upconv.bias': 'conv5_mask_b',
        })
        return mapping_to_detectron, []

    def forward(self, x, rpn_ret):
        """Compute mask-head features for the RoIs.

        Args:
            x: features forwarded to ``self.roi_xform``.
            rpn_ret: forwarded to ``self.roi_xform`` together with
                ``blob_rois='mask_rois'``.

        Returns:
            [x_fcn, x_ff]: the upsampled (2x) FCN-branch feature map and the FF-branch
            output of shape (num_rois, cfg.MRCNN.RESOLUTION ** 2).
        """
        level_feats = self.roi_xform(
            x, rpn_ret,
            blob_rois='mask_rois',
            method=cfg.MRCNN.ROI_XFORM_METHOD,
            resolution=cfg.MRCNN.ROI_XFORM_RESOLUTION,
            spatial_scale=self.spatial_scale,
            sampling_ratio=cfg.MRCNN.ROI_XFORM_SAMPLING_RATIO,
            panet=True
        )
        # Adaptive feature pooling: a dedicated conv per entry, then element-wise max.
        # Build a new list instead of mutating whatever roi_xform returned.
        level_feats = [self.mask_conv1[i](feat) for i, feat in enumerate(level_feats)]
        x = level_feats[0]
        for feat in level_feats[1:]:
            x = torch.max(x, feat)
        x = self.conv_fcn(x)

        x_fcn = F.relu(self.upconv(self.mask_conv4(x)), inplace=True)

        x_ff = self.mask_conv5_fc(self.mask_conv4_fc(x))
        # Flatten with an explicit feature size: view(n, -1) raises for n == 0
        # (zero RoIs) because the -1 dimension is then ambiguous.
        num_rois, channels, height, width = x_ff.size()
        x_ff = self.mask_fc(x_ff.view(num_rois, channels * height * width))
        return [x_fcn, x_ff]
PANet
Latest recommended article published on 2024-04-08 22:32:33