在深度学习的网络中,我们发现使用pretrain网络有利于整个的网络初始搭建,已训练好的模型参数可以提高网络的训练速度,甚至在一定程度上提高网络的效果。
Pytorch作为深度学习对新手最为有好的框架,有很多方便易用的包。其中torchvision就是视觉方面不可或缺的包。其中包含三个方面: torchvison.datasets ,torchvision.models ,torchvision.transforms ,分别是预定义好的数据集(比如MNIST、CIFAR10等)、预定义好的经典网络结构(比如AlexNet、VGG、ResNet等)和预定义好的数据增强方法(比如Resize、ToTensor等)。这些方法可以直接调用,简化我们建模的过程,也可以作为我们学习或构建新的模型的参考。
在日常使用中,我们也经常面临pretrain网络的一些问题,在此进行统一总结。
1、直接加载/自定义位置加载预训练模型
1.1、完全使用预训练模型
import torchvision.models as models
#resnet
model = models.ResNet(pretrained=True)
model = models.resnet18(pretrained=True)
model = models.resnet34(pretrained=True)
model = models.resnet50(pretrained=True)
#vgg
model = models.VGG(pretrained=True)
model = models.vgg11(pretrained=True)
model = models.vgg16(pretrained=True)
1.2 完全使用训练模型,但是换个位置加载权重;常见的位置包括自行从在本地加载以及url网络下载。
def siggraph17(pretrained=True):
model = SIGGRAPHGenerator()
if(pretrained):
import torch.utils.model_zoo as model_zoo
model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',map_location='cpu',check_hash=True))
return model
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
if pretrained:
aux_loss = True
model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
if pretrained:
# arch = arch_type + '_' + backbone + '_coco'
# model_url = model_urls[arch]
# if model_url is None:
# raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
# else:
path = os.path.abspath(os.path.dirname(__file__))
real_checkpoint = os.path.join(path, 'deeplabv3_resnet101_coco-586e9e4e.pth')
state_dict = torch.load(real_checkpoint)
model.load_state_dict(state_dict)
return model
1.3 仅仅使用网络模型结构,不需要参数
model =torchvision.models.resnet50(pretrained=False)
2、加载部分预训练模型,修改网络结构
class FeatureVGG(nn.Module):
def __init__(self , model):
super(FeatureVGG, self).__init__()
# vgg part
#removed = list(model.classifier.children())[:-1]
#model.classifier = torch.nn.Sequential(*removed)
# print(model)
#self.vgg = model.cuda()
# resnet50 part
model = nn.Sequential(*list(model.children())[:-1])
self.resnet = model.cuda()
def forward(self, x):
x = F.interpolate(x, size=(224, 224), mode='bicubic')
#x = self.vgg(x)
x = self.resnet(x)
return x
vgg19 = models.resnet50(pretrained=True) #resnet50 vgg19_bn
model_global = FeatureVGG(vgg19)