一、加载预训练模型
加载方式有两种,其中第二种在对模型进行微调(finetune)时更为常用。
1、加载框架已有的模型(如resnet等)
代码如下:
import torch
import torch.nn as nn
from torch.utils import model_zoo
import torchvision.models as models
# Build a ResNet-18 with freshly initialised weights; the pretrained
# parameters are loaded into it below.
model = models.resnet18()

# Checkpoint URLs (hosted on download.pytorch.org) for the ResNet family.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

# strict=False is the important part: if you have added your own layers on top
# of the stock resnet18, only parameters whose names match the checkpoint are
# loaded, and missing/unexpected keys are skipped instead of raising.
# NOTE: a parameter with a matching name but a different shape still raises.
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
2、从本地权重文件加载预训练好的模型
代码如下:
# Instantiate the custom network (EfficientNet1 is defined elsewhere).
model = EfficientNet1()
# Alternative: fetch the checkpoint over the network instead of reading a local file.
# model.load_state_dict(model_zoo.load_url('https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth'), strict=False)
model_dict = model.state_dict()
# map_location='cpu' lets the checkpoint be read on a machine that lacks the
# device it was saved from; load_state_dict copies the values into the model's
# own tensors afterwards.
sd = torch.load('/root/.cache/torch/checkpoints/efficientnet-b0-355c32eb.pth', map_location='cpu')
# Keep only checkpoint entries that exist in the model under the same name AND
# have the same shape; a name-only match would make load_state_dict raise on
# any layer that was resized (e.g. a replaced classifier head).
pretrained_dict = {
    k: v
    for k, v in sd.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
# Overlay the matching pretrained tensors on the model's current state, then
# load the merged dict so every key the model expects is present.
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
3、冻结某些层
代码如下:
# Freeze every layer up to and including layer1, plus all BatchNorm layers.
for name, param in model.named_parameters():
    # str.startswith accepts a tuple, so one call covers both prefixes.
    if name.startswith(('conv1', 'layer1')):
        param.requires_grad = False

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        # eval() stops the running mean/var from being updated.
        # NOTE: a later model.train() switches these layers back to training
        # mode, so this loop has to be re-applied after each such call.
        m.eval()
        # Iterate parameters() rather than touching m.weight / m.bias directly:
        # with affine=False those attributes are None and would raise.
        for p in m.parameters():
            p.requires_grad = False