首先为了尊重他人贡献,本博客参考链接为:https://blog.csdn.net/VictoriaW/article/details/72821329 (感谢知识传递者)
楼主使用预训练模型的初衷是:目标识别检测框架在不同的计算机视觉应用中性能的提升很大一部分原因是由于骨干网络的西能提升,为了使用更好的Backbone获取feature map,所以有必要学习如何使用更好的预训练模型。
套路:在pytorch库中,使用预训练模型可以分为两种方式:第一,完全加载网络结构和checkpoint的权重;第二,首先构建相同结构的网络,然后加载预训练模型的checkpoint的权重state。一般而言,使用后者是一个通用操作。
pytorch中的torchvision.models模块中存放了较常用的backbone:vgg11,vgg11_bn,vgg13, vgg13_bn, vgg16,vgg16_bn,vgg19,vgg19_bn; resnet18, resnet34, resnet50, resnet101, resnet152; densnet121, densnet162, densnet169, densnet201等等(不一一列举)
第一步,构建想用使用的预训练网络的结构
可以自己构建,可以直接调用(即从pytorch.org...)中去下载
import torchvision.models as models # 导入相应模块
resnet18 = models.resnet18(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
第二步,如果需要可以微调网络结构和筛选参数
vgg16 = models.vgg16(pretrained=True)
pretrained_dict = vgg16.state_dict()
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)