pytorch快速上手(4)-----如何导入torch中经典模型的部分结构

torchvision.model
model子包中包含了用于处理不同任务的经典模型的定义,包括:图像分类、像素级语义分割、对象检测、实例分割、人员关键点检测和视频分类。

  • 图像分类:
    在这里插入图片描述
  • 语义分割:
    在这里插入图片描述
  • 对象检测、实例分割和人员关键点检测:
    在这里插入图片描述
  • 视频分类:

ResNet 3D
ResNet Mixed Convolution
ResNet (2+1)D

1. 完整模型的导入(无预训练权重)

你可以通过调用构造函数来构造一个带有随机权重的模型(需要重新训练):

import torchvision.models as models

resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet = models.mobilenet_v2()
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()

2.导入预训练好的模型

我们使用PyTorch torch.utils.model_zoo提供预训练的模型。这些可以通过传递pretrained=True来构造:

import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet = models.mobilenet_v2(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)

3.导入预训练模型的部分结构

有时候,我们想利用一些经典模型的架构作为特征提取器,仅想导入部分结构,而不是整个模型。

以作者自己训的一个多任务模型为例,例如仅利用resnet50 4f之前的结构作为backbone,以下代码中backbone后面各个任务分支代码省略
其他模型类似,查看网络结构和torch中模型定义,用到哪儿取到哪一层即可

class nmv_res50_centernet(nn.Module):
    def __init__(self, cls_num):
        super(nmv_res50_centernet, self).__init__()#super是调用父类里的方法,此处是调用父类Module的__init__()方法
        self.resnet50 = models.resnet50(pretrained=True)#使用resnet50作为backbone


        #其他使用到的层设置
        #=====属性分支======
        ......
        #========ReID分支=========
       ......
        #=========centernet检测分支==========
       ......
        #center
       ......
        #offset
       ......
        #wh
       ...... 


    def forward(self, x):
        #===========backbone============
        out = self.resnet50.conv1(x)
        out = self.resnet50.bn1(out)
        out = self.resnet50.relu(out)
        out = self.resnet50.maxpool(out)
        res2 = self.resnet50.layer1(out)
        res3 = self.resnet50.layer2(res2)
        res4 = self.resnet50.layer3(res3)#此处self.resnet50.layer3就对应resnet模型中残差模块block4部分

        #res4_d = res4.detach()#将该层梯度截断不反传
        #将特征图tensor按n通道均分为2份,若指定大小分割可用split
        fea_id, fea_attr_det = torch.chunk(res4, 2, dim=0)
		#后面接自己任务的分支即可
        #============attr================
       ......
        #===============ReID==============
       ......
        #===============detect=============
       .......
        return 自己需要的输出
  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值