在使用pytorch进行建模时,有时候,从原有的模型的基础上进行修改,效率更高,这么做的最大优点:一方面可以节约重新造轮子的时间,另一方面可以使用原来已经预训练好的权重,为模型的训练节省时间。有时候需要重新提取模型的某个特征层或多个特征曾,比如:在更换不同的backbone,添加FPN等使用场景。下面介绍几种模型修改的方式。
1 . 通过create_feature_extractor函数获得模型的中间多个节点
create_feature_extractor函数会根据传入的参数寻找到模型正向传播的节点,然后删除该节点后面的正向传播节点,从中间截断原来的模型,只取前半段进行使用。这个函数真正的强大之处在于,可以取出模型中的任意多个中间节点作为输出,输出的方式以字典的形式。这个函数的用法在FPN中尤为常见。
我们拿resnet18为例,resnet18模型结构如下,需要注意的是,我们要取出结构图种的conv4_x,对应torchvision resnet18中的layer3:
from torchvision.models.feature_extraction import create_feature_extractor
import torch
from torchvision.models import resnet18
net = resnet18(pretrained=False)
print(net)
model = create_feature_extractor(net,return_nodes={"layer2":"out2","layer3":"out3"})
# return_nodes={"layer3":节点名称:
# "out":输出是一个字典,这个是输出字典的键值}
print(model)
x = torch.rand(1,3,224,224) # 生成一个输入tensor
y = model(x)
print(y["out2"].shape ,y["out3"].shape)
# 查看输出 输出为14 * 14 这和layer3(对应表格中的conv4-x) 完全相同
2. 采用model.children() 的方式,获得模型的前半段
采用model.children()函数可以将模型的子模块变成生成器,将生成器转化为列表,然后传入nn.Squential()进行模型的修改和截断。这种方式 比较方法1, 无法获得模型的多个输出层。但是在简单的使用场景中,更加简单高效。
import torch.nn as nn
class test(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Conv2d(3,3,1)
self.layer2 = nn.Conv2d(3,3,2)
self.layer3 = nn.Conv2d(3,3,3)
self.layer4 = nn.Sequential(
nn.Conv2d(3, 3, 4),
nn.Conv2d(3, 3, 5)
)
def forward(self,x):
out = self.layer1(x)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
return out
model = test()
for i in model.children():
print(f" >>> {i}")
model_new = nn.Sequential(*list(model.children())[:2])
for i in model_new.children():
print(f" >>> {i}")
# >>> Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
# >>> Conv2d(3, 3, kernel_size=(2, 2), stride=(1, 1))
# >>> Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
# >>> Sequential(
# (0): Conv2d(3, 3, kernel_size=(4, 4), stride=(1, 1))
# (1): Conv2d(3, 3, kernel_size=(5, 5), stride=(1, 1))
# )
# >>> Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
# >>> Conv2d(3, 3, kernel_size=(2, 2), stride=(1, 1))