利用torchvision.models实现卷积神经网络的backbone

torchvision.models

简介

torchvision.models 定义了用于处理不同任务的模型,包括图像分类、像素语义分割、对象检测、实例分割、人员关键点检测、视频分类和光流等。
在__init__文件下可以看到实现的网络列表,0.12版本实现的分类网络包括alexnet、convnext、resnet、vgg、squeezenet、inception、densenet、googlenet、mobilenet、mnasnet、shufflenet、efficientnet、regnet等,还实现了一些目标检测、特征提取、语义分割的经典网络。

格式框架

toechvision.models通过class类型保存网络结构及其所需功能模块,以dict类型保存与训练权重的下载地址,并定义一个函数用于外部调用(实现网络参数定义、预训练权重加载等功能),以下为resnet.py的格式框架。

__all__ = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
   		   "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2"]

# download pre-trained model
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
    "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
    "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
    "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
    "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
}

# structure definition
class BasicBlock(nn.Module):...
class Bottleneck(nn.Module):...
class ResNet(nn.Module):...

# definition & initialization
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...

如何改写实现backbone

在__init__文件下可以看到实现的网络列表,按功能分为分类网络、目标检测网络、
语义分割网络等。0.12版本实现的分类网络包括alexnet、convnext、resnet、vgg、squeezenet、inception、densenet、googlenet、mobilenet、mnasnet、shufflenet、efficientnet、regnet等。
目前的主流方法是利用分类网络的特征提取部分作为其他任务的特征提取网络,即删除分类网络最后的分类层(通常包括avgpool、flatten、full-connected层等),直接返回一张或多张特征图。

1. 新建目录存储backbone模型

在project目录下建立net文件夹存储backbone的py文件,不要在torchvision.models下的源文件里直接改写
例如,新建nets.GoogleNet.py文件,复制粘贴torchvision.models.googlenet.py的全部内容。

2. 删除分类层

根据最后定义的调用函数找到该网络的主干class,分析forward函数找到分类层,以resnet为例,最后三层分别对特征图进行自适应池化、一维化、全连接分类,属于网络的分类层,其输入即网络提取到的特征图,因此删除后三层。
分类层在__init__函数中的定义也要删除,否则网络初始化时依然会为其分配空间,浪费显存

class ResNet(nn.Module):
	def __init__(self, ...):
		super(ResNet, self).__init__()
		...
	
	def _forward_impl(self, x):
		x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)  # output feature map

		# classification layer, delete here
        # x = self.avgpool(x)
        # x = torch.flatten(x, 1)
        # x = self.fc(x)

        return x
		
	def forward(self, x):
		return self._forward_impl(x)

3. 输出改写

3-1. 选定特征层

如果你的目标检测网络只利用一层特征图进行目标预测,则对最后输出的特征图进行通道对齐即可。
如果目标检测网络使用了FPN等结构,需要多层特征图,则需要根据上述分类网络的结构进行分析,划分出三层或多层特征图进行输出。
特征层的选定有两种方法:

  • 输出通道数选择。以resnet为例,layer2、layer3、layer4的输出特征图通道数分别为512、1024、2048,与我的目标检测网络FPN结构的输入通道数匹配,因此选择这三层特征图构成金字塔
  • 网络结构选择。以GoogleNet为例,可以参考其aux_block的位置,分别输出inception4a、inception4d和inception5b的特征图构成金字塔。

3-2. 输出通道改写

选定的输出层通道数应与neck部分的输入通道数相匹配,在此建议修改neck结构来适应backbone的输出,而非修改backbone的结构,以便最大化利用预训练的权重,提高迁移学习的效果

4. 预训练权重加载

在train.py中实现预训练权重加载,以下为使用torchvision.models实现model.backbone的权重加载。
参数strict=False的作用是只加载结构匹配的参数,被删除的分类层、输出通道数被修改的特征提取层参数都不会被加载,因此建议在目标检测网络的neck部分修改通道参数,尽量不要改动backbone的结构,否则作为训练初期被冻结的backbone会导致网络训练效果不佳。

model_urls = {"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
			  "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
        	  "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
       		  "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
        	  "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth", }
        	  
model.backbone.load_state_dict(model_zoo.load_url(model_urls['resnet50'], model_dir='./model_data/'), strict=False)

若训练中断,或选取一个epoch.pth作为起点重新训练,则需要加载整个网络的权重,加载代码如下

model_path = "logs/weights/Epoch_44.pth"

model_dict = model.state_dict()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pretrained_dict = torch.load(model_path, map_location=device)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值