1、从模型调用开始看,首先看调用
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
def resnet50(pretrained=False, progress=True, **kwargs):
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
可以看到resnet至少需要两个显示参数,分别是block和layers。block是resnet18和resnet50中应用的两种不同结构,分别是basicblock和bottleneck。layers就是网络层数,也就是每个block的个数。
2、然后看网络结构
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
#参数比调用多几个,模型相较于最初发文章的时候有过更新
#block: basicblock或者bottleneck,后续会提到
#layers:每个block的个数,如resnet50, layers=[3,4,6,3],网络结构图中可以看到
#num_classes: 数据库类别数量
#zero_init_residual:其他论文中提到的一点小trick,残差参数为0
#groups:卷积层分组,应该是为了resnext扩展
#width_per_group:同上,此外还可以是wideresnet扩展
#replace_stride_with_dilation:空洞卷积,非原论文内容
#norm_layer:原论文用BN,此处设为可自定义
# 中间部分代码省略,只看模型搭建部分
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2