pytorch中-pretrain模型-引用与修改

在深度学习的网络中,我们发现使用pretrain网络有利于整个的网络初始搭建,已训练好的模型参数可以提高网络的训练速度,甚至在一定程度上提高网络的效果。

Pytorch作为深度学习对新手最为有好的框架,有很多方便易用的包。其中torchvision就是视觉方面不可或缺的包。其中包含三个方面: torchvison.datasets ,torchvision.models ,torchvision.transforms ,分别是预定义好的数据集(比如MNIST、CIFAR10等)、预定义好的经典网络结构(比如AlexNet、VGG、ResNet等)和预定义好的数据增强方法(比如Resize、ToTensor等)。这些方法可以直接调用,简化我们建模的过程,也可以作为我们学习或构建新的模型的参考。

在日常使用中,我们也经常面临pretrain网络的一些问题,在此进行统一总结。

1、直接加载/自定义位置加载预训练模型

1.1、完全使用预训练模型

import torchvision.models as models
 
#resnet
model = models.ResNet(pretrained=True)
model = models.resnet18(pretrained=True)
model = models.resnet34(pretrained=True)
model = models.resnet50(pretrained=True)
 
#vgg
model = models.VGG(pretrained=True)
model = models.vgg11(pretrained=True)
model = models.vgg16(pretrained=True)

1.2 完全使用训练模型,但是换个位置加载权重;常见的位置包括自行从在本地加载以及url网络下载。

def siggraph17(pretrained=True):
    model = SIGGRAPHGenerator()
    if(pretrained):
        import torch.utils.model_zoo as model_zoo
        model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',map_location='cpu',check_hash=True))
    return model
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    if pretrained:
        aux_loss = True
    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
    if pretrained:
        # arch = arch_type + '_' + backbone + '_coco'
        # model_url = model_urls[arch]
        # if model_url is None:
        #     raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        # else:
        path = os.path.abspath(os.path.dirname(__file__))
        real_checkpoint = os.path.join(path, 'deeplabv3_resnet101_coco-586e9e4e.pth')
        state_dict = torch.load(real_checkpoint)
        model.load_state_dict(state_dict)
    return model

1.3 仅仅使用网络模型结构,不需要参数

model =torchvision.models.resnet50(pretrained=False)
2、加载部分预训练模型,修改网络结构
class FeatureVGG(nn.Module):
    def __init__(self , model):
        super(FeatureVGG, self).__init__()
        
        # vgg part
        #removed = list(model.classifier.children())[:-1]
        #model.classifier = torch.nn.Sequential(*removed)
        # print(model)
        #self.vgg = model.cuda()
    
        # resnet50 part
        model = nn.Sequential(*list(model.children())[:-1])
        self.resnet = model.cuda()
        
    def forward(self, x):
        x = F.interpolate(x, size=(224, 224), mode='bicubic') 
        #x = self.vgg(x)
        x = self.resnet(x)
        return x

vgg19 = models.resnet50(pretrained=True) #resnet50 vgg19_bn
model_global = FeatureVGG(vgg19)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值