import torch
# 构建神经网络所需的各种层(如线性层、卷积层、循环层等)和损失函数。
from torch import nn
# 在神经网络构建和训练中实现一些常见的操作
# 激活函数(ReLU、Sigmoid等)、池化(Pooling)、归一化(Normalization)
from torch.nn import functional as F
# torchvision包中的一个模块,包含了大量预训练的模型,如ResNet、VGG、Inception
from torchvision import models
# torchvision.ops.misc:
# 杂项操作,例如FrozenBatchNorm2d,这是一个冻结版本的批量归一化层
# FrozenBatchNorm2d
# 一个特殊的批量归一化层,它在训练时不会更新其参数
# 通常用于微调预训练模型,以避免对归一化层参数的过度拟合
from torchvision.ops.misc import FrozenBatchNorm2d
class Backbone(nn.Module):
def __init__(
self,
name: str,# ResNet的版本,如'resnet50'
pretrained: bool,# 是否使用预训练的权重
dilation: bool,# 是否使用扩张卷积
reduction: int,# 特征融合时的降维因子
swav: bool,# 是否使用SWAV(Swiss Army Vector)预训练权重
requires_grad: bool# 是否需要在训练中计算梯度
):
super(Backbone, self).__init__()# 调用基类的初始化方法
# 获取指定的ResNet模型,使用FrozenBatchNorm2d来冻结批量归一化层
resnet = getattr(models, name)(
# 指定是否使用扩张卷积替换步长
replace_stride_with_dilation=[False, False, dilation],
# 是否加载预训练权重
pretrained=pretrained,
# 使用冻结的批量归一化层
norm_layer=FrozenBatchNorm2d
)
# 保存ResNet模型
self.backbone = resnet
# 保存降维因子
self.reduction = reduction
# 如果使用SWAV预训练权重,加载并应用
if name == 'resnet50' and swav:
# 加载预训练模型的权重
checkpoint = torch.hub.load_state_dict_from_url(
'https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar',
map_location="cpu"
)
# 清理状态字典的键
state_dict = {k.replace("module.", ""): v for k, v in checkpoint.items()}
# 加载状态字典到模型中
# 根据ResNet版本,设置不同层的通道数
self.backbone.load_state_dict(state_dict, strict=False)
# concatenation of layers 2, 3 and 4
self.num_channels = 896 if name in ['resnet18', 'resnet34'] else 3584
# 遍历模型的参数,设置是否需要计算梯度
for n, param in self.backbone.named_parameters():
# 如果参数不在layer2、layer3、layer4中,则不需要计算梯度
if 'layer2' not in n and 'layer3' not in n and 'layer4' not in n:
param.requires_grad_(False)
# 否则,根据requires_grad参数决定是否计算梯度
else:
param.requires_grad_(requires_grad)
def forward(self, x):
size = x.size(-2) // self.reduction, x.size(-1) // self.reduction
x = self.backbone.conv1(x)
x = self.backbone.bn1(x)
x = self.backbone.relu(x)
x = self.backbone.maxpool(x)
x = self.backbone.layer1(x)
x = layer2 = self.backbone.layer2(x)
x = layer3 = self.backbone.layer3(x)
x = layer4 = self.backbone.layer4(x)
x = torch.cat([
F.interpolate(f, size=size, mode='bilinear', align_corners=True)
for f in [layer2, layer3, layer4]
], dim=1)
return x
功能解释:
Backbone
类接收多个参数来配置主干网络的行为,包括是否使用预训练权重、是否使用扩张卷积、是否使用SWAV预训练权重等。- 使用
getattr(models, name)
动态获取torchvision.models
中定义的指定名称的ResNet模型。 replace_stride_with_dilation
参数用于在网络的第三层中用扩张卷积替换步长,这有助于捕获更广泛的上下文信息。pretrained
参数控制是否加载在ImageNet数据集上预训练的权重。FrozenBatchNorm2d
用于冻结批量归一化层的参数,避免在训练中更新。- 如果指定了
swav
参数,将使用特定于SWAV方法的预训练权重。 num_channels
根据使用的ResNet版本设置网络中特定层的通道数。- 在模型参数中,除了
layer2
、layer3
、layer4
之外的层的参数都设置为不计算梯度,以冻结这些层的权重。对于layer2
、layer3
、layer4
中的参数,则根据requires_grad
参数决定是否计算梯度。
整体而言,Backbone
类提供了一个灵活的方式来配置和初始化ResNet主干网络,以适应不同的训练需求和迁移学习场景。
SWAV预训练权重 这是啥
SWAV(Swiss Army Vector)预训练权重是一种用于视觉模型的特殊预训练权重集合。"Swiss Army"这个名字暗示了这种预训练方法的多功能性和通用性,就像瑞士军刀一样。SWAV是一种自监督学习方法,它旨在训练一个模型来生成图像的通用特征表示,这些特征可以用于多种下游任务,如分类、目标检测或分割。
SWAV的关键特点包括:
-
自监督学习:SWAV不依赖于标注数据来学习图像的特征表示。它使用对比学习的方法,通过将图像的不同视图(例如,通过裁剪或颜色变换得到的视图)拉近,并将不匹配的视图推远,来学习区分不同图像的能力。
-
多尺度特征:SWAV学习在多个尺度上表示图像,这有助于模型捕捉不同大小的对象和模式。
-
聚类机制:SWAV使用聚类技术来组织学习到的特征,使得相似的图像特征更接近,不同的图像特征更分散。
-
通用性:SWAV预训练的模型可以作为一个通用的视觉模型,适用于多种视觉任务,而不仅仅是特定的任务。
-
预训练权重:通过SWAV方法预训练得到的权重可以作为下游任务的起点,通常可以提高模型在这些任务上的性能。
在你提供的代码上下文中,如果使用SWAV预训练权重,模型会加载这些权重并在此基础上进行微调或直接用于推理。这通常有助于提高模型在特定任务上的性能,特别是当可用的标注数据有限时。通过加载SWAV预训练权重,Backbone
类配置的ResNet模型能够利用这种强大的特征表示能力。