# ResNet backbone
# Source: ./fcos_core/modeling/backbone/resnet.py
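"""
ResNet backbone and its many variants, all configured through `cfg` parameters.

Example usage. The component names are strings that may be specified in the
config file, e.g.:
    MODEL.RESNETS.STEM_FUNC: "StemWithFixedBatchNorm"
    MODEL.RESNETS.TRANS_FUNC: "BottleneckWithFixedBatchNorm"
    MODEL.BACKBONE.CONV_BODY: "R-50-C4"      # -> ResNet50StagesTo4
OR:
    MODEL.RESNETS.STEM_FUNC: "StemWithGN"
    MODEL.RESNETS.TRANS_FUNC: "BottleneckWithGN"
    MODEL.BACKBONE.CONV_BODY: "R-50-C4"      # -> ResNet50StagesTo4
and then:
    model = ResNet(cfg)
Custom implementations may be written in user code and hooked in by
registering them in the module-level registries (`_STEM_MODULES`,
`_TRANSFORMATION_MODULES`, `_STAGE_SPECS`).
"""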
from collections import namedtuple
import torch
import torch. nn. functional as F
from torch import nn
from fcos_core. layers import FrozenBatchNorm2d
from fcos_core. layers import Conv2d
from fcos_core. layers import DFConv2d
from fcos_core. modeling. make_layers import group_norm
from fcos_core. utils. registry import Registry
# Description of one ResNet stage:
#   index           -- stage number; ResNet registers the stage as "layer<index>"
#                      and scales its channel widths by 2 ** (index - 1)
#   block_count     -- number of residual blocks stacked in the stage
#   return_features -- whether ResNet.forward includes this stage's output
StageSpec = namedtuple("StageSpec", "index block_count return_features")


def _stage_specs(*rows):
    """Turn ``(index, block_count, return_features)`` rows into a tuple of StageSpec."""
    return tuple(StageSpec(*row) for row in rows)


# C4 / C5 layouts: only the last listed stage's output is returned.
ResNet50StagesTo5 = _stage_specs((1, 3, False), (2, 4, False), (3, 6, False), (4, 3, True))
ResNet50StagesTo4 = _stage_specs((1, 3, False), (2, 4, False), (3, 6, True))
ResNet101StagesTo5 = _stage_specs((1, 3, False), (2, 4, False), (3, 23, False), (4, 3, True))
ResNet101StagesTo4 = _stage_specs((1, 3, False), (2, 4, False), (3, 23, True))
# FPN layouts: the output of every stage is returned.
ResNet50FPNStagesTo5 = _stage_specs((1, 3, True), (2, 4, True), (3, 6, True), (4, 3, True))
ResNet101FPNStagesTo5 = _stage_specs((1, 3, True), (2, 4, True), (3, 23, True), (4, 3, True))
ResNet152FPNStagesTo5 = _stage_specs((1, 3, True), (2, 8, True), (3, 36, True), (4, 3, True))
class ResNet(nn.Module):
    """ResNet backbone whose components are chosen by name from ``cfg``.

    ``cfg.MODEL.RESNETS.STEM_FUNC`` selects the stem from ``_STEM_MODULES``,
    ``cfg.MODEL.BACKBONE.CONV_BODY`` selects the stage layout from
    ``_STAGE_SPECS`` and ``cfg.MODEL.RESNETS.TRANS_FUNC`` selects the residual
    block class from ``_TRANSFORMATION_MODULES``. Each stage is registered as
    a submodule named ``layer<index>``.
    """

    def __init__(self, cfg):
        super().__init__()
        resnets = cfg.MODEL.RESNETS

        # All registry lookups happen before any layer is constructed.
        stem_cls = _STEM_MODULES[resnets.STEM_FUNC]
        stage_specs = _STAGE_SPECS[cfg.MODEL.BACKBONE.CONV_BODY]
        block_cls = _TRANSFORMATION_MODULES[resnets.TRANS_FUNC]

        self.stem = stem_cls(cfg)

        groups = resnets.NUM_GROUPS
        base_bottleneck = groups * resnets.WIDTH_PER_GROUP
        channels_in = resnets.STEM_OUT_CHANNELS
        base_out = resnets.RES2_OUT_CHANNELS

        self.stages = []
        self.return_features = {}
        for spec in stage_specs:
            layer_name = "layer" + str(spec.index)
            # Widths double with every stage index beyond the first.
            scale = 2 ** (spec.index - 1)
            stage_out = base_out * scale
            with_dcn = resnets.STAGE_WITH_DCN[spec.index - 1]
            stage = _make_stage(
                block_cls,
                channels_in,
                base_bottleneck * scale,
                stage_out,
                spec.block_count,
                groups,
                resnets.STRIDE_IN_1X1,
                # Only stage 1 keeps the stem's resolution.
                first_stride=2 if spec.index > 1 else 1,
                dcn_config=dict(
                    stage_with_dcn=with_dcn,
                    with_modulated_dcn=resnets.WITH_MODULATED_DCN,
                    deformable_groups=resnets.DEFORMABLE_GROUPS,
                ),
            )
            channels_in = stage_out
            self.add_module(layer_name, stage)
            self.stages.append(layer_name)
            self.return_features[layer_name] = spec.return_features

        self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT)

    def _freeze_backbone(self, freeze_at):
        """Disable gradients for the first ``freeze_at`` modules.

        Modules are counted in the order stem, layer1, layer2, ...; so
        ``freeze_at == 1`` freezes only the stem, while ``freeze_at <= 0``
        freezes nothing.
        """
        if freeze_at < 0:
            return
        for position in range(freeze_at):
            target = self.stem if position == 0 else getattr(self, "layer" + str(position))
            for param in target.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Run the stem and every stage, collecting outputs flagged in ``return_features``."""
        feats = []
        y = self.stem(x)
        for name in self.stages:
            y = getattr(self, name)(y)
            if self.return_features[name]:
                feats.append(y)
        return feats
class ResNetHead(nn.Module):
    """A stack of ResNet stages (no stem) described by ``StageSpec`` entries.

    The stage with index ``i`` outputs ``res2_out_channels * 2 ** (i - 1)``
    channels and uses ``num_groups * width_per_group * 2 ** (i - 1)``
    bottleneck channels. The first stage expects half of its own output width
    as input (the width of stage ``i - 1`` under this doubling rule); every
    later stage consumes the previous stage's output.

    Args:
        block_module: name of the residual block in ``_TRANSFORMATION_MODULES``.
        stages: non-empty sequence of ``StageSpec``.
        num_groups: ``groups`` of the blocks' 3x3 convolutions.
        width_per_group: per-group bottleneck width at index 1.
        stride_in_1x1: forwarded to the blocks.
        stride_init: stride of the first stage's first block; a falsy value
            means "2 unless the stage index is 1". Later stages always use
            that default.
        res2_out_channels: output width of an index-1 stage.
        dilation: dilation forwarded to every block.
        dcn_config: optional dict of deformable-conv options forwarded to the
            blocks; ``None`` is treated as an empty dict.

    Attributes:
        out_channels: output width of the last stage.
    """

    def __init__(
        self,
        block_module,
        stages,
        num_groups=1,
        width_per_group=64,
        stride_in_1x1=True,
        stride_init=None,
        res2_out_channels=256,
        dilation=1,
        dcn_config=None
    ):
        super(ResNetHead, self).__init__()
        # The built-in blocks call dcn_config.get(...), which fails on None.
        if dcn_config is None:
            dcn_config = {}

        stage2_bottleneck_channels = num_groups * width_per_group
        # Input of the first stage: half of that stage's output width.
        in_channels = (res2_out_channels * 2 ** (stages[0].index - 1)) // 2

        block_module = _TRANSFORMATION_MODULES[block_module]

        self.stages = []
        stride = stride_init
        out_channels = None
        for stage in stages:
            name = "layer" + str(stage.index)
            # Widths are recomputed per stage so multi-stage heads chain
            # correctly (a single stage gets exactly the same widths).
            stage_factor = 2 ** (stage.index - 1)
            out_channels = res2_out_channels * stage_factor
            bottleneck_channels = stage2_bottleneck_channels * stage_factor
            if not stride:
                stride = int(stage.index > 1) + 1
            module = _make_stage(
                block_module,
                in_channels,
                bottleneck_channels,
                out_channels,
                stage.block_count,
                num_groups,
                stride_in_1x1,
                first_stride=stride,
                dilation=dilation,
                dcn_config=dcn_config
            )
            # stride_init applies to the first stage only.
            stride = None
            in_channels = out_channels
            self.add_module(name, module)
            self.stages.append(name)
        self.out_channels = out_channels

    def forward(self, x):
        """Apply every stage in order and return the last stage's output."""
        for stage in self.stages:
            x = getattr(self, stage)(x)
        return x
def _make_stage (
transformation_module,
in_channels,
bottleneck_channels,
out_channels,
block_count,
num_groups,
stride_in_1x1,
first_stride,
dilation= 1 ,
dcn_config= None
) :
blocks = [ ]
stride = first_stride
for _ in range ( block_count) :
blocks. append(
transformation_module(
in_channels,
bottleneck_channels,
out_channels,
num_groups,
stride_in_1x1,
stride,
dilation= dilation,
dcn_config= dcn_config
)
)
stride = 1
in_channels = out_channels
return nn. Sequential( * blocks)
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv -> 3x3 conv -> 1x1 conv, plus a shortcut.

    Each convolution is followed by a ``norm_func`` layer. ReLU is applied
    after the first two normalizations and after the residual addition.

    Args:
        in_channels: number of input channels.
        bottleneck_channels: width of the inner 1x1 and 3x3 convolutions.
        out_channels: number of output channels.
        num_groups: ``groups`` of the 3x3 convolution.
        stride_in_1x1: if True the block stride is applied by the first 1x1
            conv, otherwise by the 3x3 conv.
        stride: block stride; forced to 1 when ``dilation > 1``.
        dilation: dilation of the 3x3 conv (also its padding in the plain
            convolution path).
        norm_func: callable ``norm_func(num_channels)`` that builds a
            normalization layer.
        dcn_config: optional dict with keys ``"stage_with_dcn"``,
            ``"with_modulated_dcn"`` and ``"deformable_groups"``; when
            ``stage_with_dcn`` is truthy the 3x3 conv is a ``DFConv2d``.
            ``None`` is treated as an empty dict (plain convolution).
    """

    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups,
        stride_in_1x1,
        stride,
        dilation,
        norm_func,
        dcn_config
    ):
        super(Bottleneck, self).__init__()
        # The subclasses and _make_stage default dcn_config to None; calling
        # .get() on None would raise AttributeError, so normalise it here.
        if dcn_config is None:
            dcn_config = {}

        self.downsample = None
        if in_channels != out_channels:
            # A dilated block runs with stride 1 (see below), so its shortcut
            # projection must not stride either.
            down_stride = stride if dilation == 1 else 1
            self.downsample = nn.Sequential(
                Conv2d(
                    in_channels, out_channels,
                    kernel_size=1, stride=down_stride, bias=False
                ),
                norm_func(out_channels),
            )
            for layer in self.downsample.modules():
                if isinstance(layer, Conv2d):
                    nn.init.kaiming_uniform_(layer.weight, a=1)
        # NOTE(review): when in_channels == out_channels but stride > 1 (and
        # dilation == 1) no shortcut projection is built, so the residual add
        # in forward() receives an un-strided identity; confirm callers never
        # request that combination.

        if dilation > 1:
            stride = 1

        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        # Layer creation and kaiming init keep their original order so the
        # random initialisation stream is unchanged.
        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
        )
        self.bn1 = norm_func(bottleneck_channels)

        with_dcn = dcn_config.get("stage_with_dcn", False)
        if with_dcn:
            deformable_groups = dcn_config.get("deformable_groups", 1)
            with_modulated_dcn = dcn_config.get("with_modulated_dcn", False)
            self.conv2 = DFConv2d(
                bottleneck_channels,
                bottleneck_channels,
                with_modulated_dcn=with_modulated_dcn,
                kernel_size=3,
                stride=stride_3x3,
                groups=num_groups,
                dilation=dilation,
                deformable_groups=deformable_groups,
                bias=False
            )
        else:
            self.conv2 = Conv2d(
                bottleneck_channels,
                bottleneck_channels,
                kernel_size=3,
                stride=stride_3x3,
                padding=dilation,
                bias=False,
                groups=num_groups,
                dilation=dilation
            )
            nn.init.kaiming_uniform_(self.conv2.weight, a=1)

        self.bn2 = norm_func(bottleneck_channels)

        self.conv3 = Conv2d(
            bottleneck_channels, out_channels, kernel_size=1, bias=False
        )
        self.bn3 = norm_func(out_channels)

        for layer in [self.conv1, self.conv3]:
            nn.init.kaiming_uniform_(layer.weight, a=1)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        identity = x

        out = F.relu_(self.bn1(self.conv1(x)))
        out = F.relu_(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return F.relu_(out)
class BaseStem(nn.Module):
    """ResNet stem: 7x7 stride-2 conv, normalization, ReLU, 3x3 stride-2 max-pool.

    Args:
        cfg: config node; ``cfg.MODEL.RESNETS.STEM_OUT_CHANNELS`` sets the
            number of output channels.
        norm_func: callable ``norm_func(num_channels)`` that builds the
            normalization layer.
    """

    def __init__(self, cfg, norm_func):
        super().__init__()
        width = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
        # 3 input channels, stride 2, padding 3.
        self.conv1 = Conv2d(3, width, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_func(width)
        nn.init.kaiming_uniform_(self.conv1.weight, a=1)

    def forward(self, x):
        """Apply conv -> norm -> ReLU, then downsample with a max-pool."""
        activated = F.relu_(self.bn1(self.conv1(x)))
        return F.max_pool2d(activated, kernel_size=3, stride=2, padding=1)
class BottleneckWithFixedBatchNorm(Bottleneck):
    """Bottleneck block whose normalization layers are ``FrozenBatchNorm2d``."""

    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1,
        dcn_config=None
    ):
        super().__init__(
            norm_func=FrozenBatchNorm2d,
            in_channels=in_channels,
            out_channels=out_channels,
            bottleneck_channels=bottleneck_channels,
            num_groups=num_groups,
            stride=stride,
            stride_in_1x1=stride_in_1x1,
            dilation=dilation,
            dcn_config=dcn_config,
        )
class StemWithFixedBatchNorm(BaseStem):
    """Stem whose normalization layer is ``FrozenBatchNorm2d``."""

    def __init__(self, cfg):
        super().__init__(cfg, norm_func=FrozenBatchNorm2d)
class BottleneckWithGN(Bottleneck):
    """Bottleneck block whose normalization layers are built by ``group_norm``."""

    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1,
        dcn_config=None
    ):
        super().__init__(
            norm_func=group_norm,
            in_channels=in_channels,
            out_channels=out_channels,
            bottleneck_channels=bottleneck_channels,
            num_groups=num_groups,
            stride=stride,
            stride_in_1x1=stride_in_1x1,
            dilation=dilation,
            dcn_config=dcn_config,
        )
class StemWithGN(BaseStem):
    """Stem whose normalization layer is built by ``group_norm``."""

    def __init__(self, cfg):
        super().__init__(cfg, norm_func=group_norm)
# Residual-block classes, looked up by cfg.MODEL.RESNETS.TRANS_FUNC (and by
# the ``block_module`` name given to ResNetHead).
_TRANSFORMATION_MODULES = Registry(dict(
    BottleneckWithFixedBatchNorm=BottleneckWithFixedBatchNorm,
    BottleneckWithGN=BottleneckWithGN,
))
# Stem classes, looked up by cfg.MODEL.RESNETS.STEM_FUNC.
_STEM_MODULES = Registry(dict(
    StemWithFixedBatchNorm=StemWithFixedBatchNorm,
    StemWithGN=StemWithGN,
))
# Stage layouts, looked up by cfg.MODEL.BACKBONE.CONV_BODY. The RetinaNet
# names reuse the corresponding FPN layouts.
_STAGE_SPECS = Registry(dict([
    ("R-50-C4", ResNet50StagesTo4),
    ("R-50-C5", ResNet50StagesTo5),
    ("R-101-C4", ResNet101StagesTo4),
    ("R-101-C5", ResNet101StagesTo5),
    ("R-50-FPN", ResNet50FPNStagesTo5),
    ("R-50-FPN-RETINANET", ResNet50FPNStagesTo5),
    ("R-101-FPN", ResNet101FPNStagesTo5),
    ("R-101-FPN-RETINANET", ResNet101FPNStagesTo5),
    ("R-152-FPN", ResNet152FPNStagesTo5),
]))