🚩🚩🚩Transformer实战-系列教程总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码
DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类/ConvertCocoPolysToMask类)
DETR 源码解读2(DETR类)
DETR 源码解读3(位置编码:Joiner类/PositionEmbeddingSine类)
DETR 源码解读4(BackboneBase类/Backbone类)
DETR 源码解读5(Transformer类)
DETR 源码解读6(编码器:TransformerEncoder类/TransformerEncoderLayer类)
DETR 源码解读7(解码器:TransformerDecoder类/TransformerDecoderLayer类)
DETR 源码解读8(训练函数/损失函数)
7、BackboneBase类
位置:models/backbone.py/BackboneBase类
7.1 构造函数
class BackboneBase(nn.Module):
def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
super().__init__()
for name, parameter in backbone.named_parameters():
if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
parameter.requires_grad_(False)
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {'layer4': "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
- 定义一个继承nn.Module的类
- 构造函数,传入4个参数:
backbone
:一个nn.Module
对象,代表用于特征提取的骨架网络train_backbone
:是否训练backbonenum_channels
:backbone通道数return_interm_layers
:是否返回backbone的中间层输出
- 初始化
- 遍历backbone的所有参数,
named_parameters()
方法返回网络中所有参数的迭代器,包括参数的名称和值 - 如果train_backbone设置为False,且不训练
layer2
、layer3
和layer4
,也就是说如果train_backbone为False,backbone的所有层的所有参数都不需要训练,即所有层都被冻住 - 不需要训练的参数的
requires_grad
属性设置为False
- 根据
return_interm_layers
的值 - 选择性地设置
return_layers
字典 - 一个层对应一个值
- 这个字典定义了哪些层的输出将被返回
- 创建
IntermediateLayerGetter
实例,它封装了backbone,根据return_layers
字典决定返回哪些层的输出,IntermediateLayerGetter来自torchvision - num_channels
7.2 前向传播
def forward(self, tensor_list: NestedTensor):
xs = self.body(tensor_list.tensors)
out: Dict[str, NestedTensor] = {}
for name, x in xs.items():
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
out[name] = NestedTensor(x, mask)
return out
- 前向传播函数,接收
NestedTensor
对象作为输入 - xs ,获取指定层的输出
- out,初始化一个字典,存储每个返回层的输出及其对应的新掩码
- 遍历xsitems
- 获取mask
- 确认mask存在
- 计算新的掩码
- 将输出和新掩码封装为
NestedTensor
对象 - 返回out字典
8、Backbone类
8.1 Backbone类
位置:models/backbone.py/Backbone类
class Backbone(BackboneBase):
def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
- 定义一个继承BackboneBase的类
- 初始化方法,接受四个参数:
name
:字符串,指定要使用的ResNet模型的名称(如resnet50
、resnet101
等)train_backbone
:布尔值,指示是否训练backbonereturn_interm_layers
:布尔值,指示是否返回backbone的中间层输出dilation
:布尔值,指示在网络的最后几层是否应用空洞卷积(dilation)以增加感受野
- 通过
torchvision.models
动态获取指定名称的ResNet模型 - replace_stride_with_dilation,最后一个stage应用空洞卷积
pretrained
,根据is_main_process()
的返回值决定是否加载预训练权重,norm_layer
设置为FrozenBatchNorm2d
,在backbone中使用冻结的批归一化- 根据ResNet模型的不同,设置不同的输出通道数
- 调用基类
BackboneBase
的初始化方法,传递创建的backbone
实例和其他参数
这个Backbone
类通过提供对ResNet模型的封装,允许用户灵活地选择不同的配置,例如是否训练Backbone、是否返回中间层输出以及是否在网络后段应用空洞卷积。同时,通过使用冻结的批量归一化层,可以在不调整BN层参数的情况下,利用预训练的模型进行特征提取
8.2 build_backbone()函数
位置:models/backbone.py/build_backbone()函数
本项目的backbone,主要是调用resnet,用来提取图像特征,进而构建图像序列做Transformer的输入,backbone的构建主要通过这个函数来实现:
def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model
这段代码定义了一个名为build_backbone
的函数,用于根据提供的参数构建一个含有位置编码的骨架网络模型。以下是对这段代码的逐行解释:
- 函数
build_backbone
,接收命令行参数 - position_embedding ,调用
build_position_encoding
,函数构建位置编码 - 通过lr_backbone(backbone的学习率)是否大于0来决定是否训练backbone
args.masks
指示是否需要骨架网络返回中间层的输出- 通过Backbone类构建backbone
- 通过Joiner类传入backbone和位置编码,建立backbone模型
DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类/ConvertCocoPolysToMask类)
DETR 源码解读2(DETR类)
DETR 源码解读3(位置编码:Joiner类/PositionEmbeddingSine类)
DETR 源码解读4(BackboneBase类/Backbone类)
DETR 源码解读5(Transformer类)
DETR 源码解读6(编码器:TransformerEncoder类/TransformerEncoderLayer类)
DETR 源码解读7(解码器:TransformerDecoder类/TransformerDecoderLayer类)
DETR 源码解读8(训练函数/损失函数)