pytorch 对原有模型进行修改,创建新的backbone的方法

         在使用pytorch进行建模时,有时候,从原有的模型的基础上进行修改,效率更高,这么做的最大优点:一方面可以节约重新造轮子的时间,另一方面可以使用原来已经预训练好的权重,为模型的训练节省时间。有时候需要重新提取模型的某个特征层或多个特征曾,比如:在更换不同的backbone,添加FPN等使用场景。下面介绍几种模型修改的方式。

1 . 通过create_feature_extractor函数获得模型的中间多个节点

        create_feature_extractor函数会根据传入的参数寻找到模型正向传播的节点,然后删除该节点后面的正向传播节点,从中间截断原来的模型,只取前半段进行使用。这个函数真正的强大之处在于,可以取出模型中的任意多个中间节点作为输出,输出的方式以字典的形式。这个函数的用法在FPN中尤为常见。

        我们拿resnet18为例,resnet18模型结构如下,需要注意的是,我们要取出结构图种的conv4_x,对应torchvision resnet18中的layer3:

from torchvision.models.feature_extraction import create_feature_extractor
import torch
from torchvision.models import resnet18

net = resnet18(pretrained=False)
print(net)

model = create_feature_extractor(net,return_nodes={"layer2":"out2","layer3":"out3"})
# return_nodes={"layer3":节点名称:
#                  "out":输出是一个字典,这个是输出字典的键值}
print(model)
x = torch.rand(1,3,224,224)  # 生成一个输入tensor
y = model(x)
print(y["out2"].shape ,y["out3"].shape) 
# 查看输出 输出为14 * 14 这和layer3(对应表格中的conv4-x) 完全相同

2. 采用model.children() 的方式,获得模型的前半段

        采用model.children()函数可以将模型的子模块变成生成器,将生成器转化为列表,然后传入nn.Squential()进行模型的修改和截断。这种方式 比较方法1, 无法获得模型的多个输出层。但是在简单的使用场景中,更加简单高效。

import torch.nn as nn

class test(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(3,3,1)
        self.layer2 = nn.Conv2d(3,3,2)
        self.layer3 = nn.Conv2d(3,3,3)
        self.layer4 = nn.Sequential(
            nn.Conv2d(3, 3, 4),
            nn.Conv2d(3, 3, 5)
        )
    def forward(self,x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out


model = test()

for i in model.children():
    print(f" >>>  {i}")
model_new = nn.Sequential(*list(model.children())[:2])
for i in model_new.children():
    print(f" >>>  {i}")


#  >>>  Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
#  >>>  Conv2d(3, 3, kernel_size=(2, 2), stride=(1, 1))
#  >>>  Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
#  >>>  Sequential(
#   (0): Conv2d(3, 3, kernel_size=(4, 4), stride=(1, 1))
#   (1): Conv2d(3, 3, kernel_size=(5, 5), stride=(1, 1))
# )
#  >>>  Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
#  >>>  Conv2d(3, 3, kernel_size=(2, 2), stride=(1, 1))

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值