加载COCO预训练DeeplabV3
DeeplabV3 ResNet101
PyTorch(torchvision)可以直接加载在COCO上预训练过的DeepLabV3模型,用于语义分割任务。模型在COCO train2017的一个子集上进行预训练,该子集只包含Pascal VOC中出现的20个类别;加上背景类,对应下文接口中的默认值num_classes=21。
调用
对于ResNet101为backbone的DeeplabV3,可以直接使用如下API调用:
torchvision.models.segmentation.deeplabv3_resnet101(pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs)
torchvision.models.segmentation源码
API部分的源码,定义网络。源码见pytorch官网
接口的定义函数deeplabv3_resnet101
该函数本身只是一层封装:它把网络结构固定为'deeplabv3'、主干网络固定为'resnet101',再连同其余参数一起交给_load_model。_load_model的主要参数为:网络的结构(fcn或deeplabv3),主干网络(resnet50或resnet101)。
def deeplabv3_resnet101(pretrained=False, progress=True,
                        num_classes=21, aux_loss=None, **kwargs):
    """Build a DeepLabV3 segmentation model on top of a ResNet-101 backbone.

    This is a thin convenience wrapper: it pins the architecture to
    ``'deeplabv3'`` and the backbone to ``'resnet101'`` and delegates
    everything else to ``_load_model``.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO
            train2017 which contains the same classes as Pascal VOC.
        progress (bool): If True, displays a progress bar of the download
            to stderr.
        num_classes (int): Number of output classes, forwarded unchanged
            to ``_load_model``.
        aux_loss: Auxiliary-loss switch, forwarded unchanged to
            ``_load_model``.
        **kwargs: Extra keyword arguments forwarded to ``_load_model``.

    Returns:
        Whatever ``_load_model`` returns for this architecture/backbone pair.
    """
    architecture, backbone_name = 'deeplabv3', 'resnet101'
    return _load_model(architecture, backbone_name, pretrained, progress,
                       num_classes, aux_loss, **kwargs)
加载模型的函数_load_model
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    """Build a segmentation model and optionally load COCO-pretrained weights.

    Args:
        arch_type (str): Architecture name, e.g. ``'deeplabv3'``. Used both
            to build the network and as the prefix of the checkpoint key.
        backbone (str): Backbone name, e.g. ``'resnet101'``.
        pretrained (bool): If true, the checkpoint registered in
            ``model_urls`` under ``'<arch_type>_<backbone>_coco'`` is
            downloaded and loaded into the model.
        progress (bool): Forwarded to ``load_state_dict_from_url``.
        num_classes (int): Forwarded to ``_segm_resnet``.
        aux_loss: Forwarded to ``_segm_resnet``; forced to ``True`` when
            ``pretrained`` is true (presumably because the published
            checkpoint includes the auxiliary head -- confirm).
        **kwargs: Forwarded to ``_segm_resnet``. When ``pretrained`` is true,
            ``pretrained_backbone`` is overridden to ``False`` (see below).

    Returns:
        The constructed model, with checkpoint weights loaded when
        ``pretrained`` is true.

    Raises:
        NotImplementedError: If ``pretrained`` is true and the ``model_urls``
            entry for the requested combination is ``None``.
    """
    model_url = None
    if pretrained:
        # Resolve and validate the checkpoint URL *before* building the
        # network, so an unsupported combination fails fast instead of
        # after an expensive model construction.
        arch = arch_type + '_' + backbone + '_coco'
        model_url = model_urls[arch]
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        aux_loss = True
        # load_state_dict() below replaces the model's weights with the full
        # checkpoint, so separately initialising the backbone from its own
        # pretrained weights would only be wasted download/initialisation.
        kwargs['pretrained_backbone'] = False

    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)
    return model
创建用于分割的resnet函数_segm_resnet
def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
backbone = resnet.__dict__[backbone_name](
pretrained=pretrained_backbone,
replace_stride_with_dilation=[False, True, True])
return_layers = {
'layer4': 'out'}
if aux:
return_layers['layer3'] = 'aux'
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
aux_classif