import torch
import torch.nn as nn
class Conv2d_BN(torch.nn.Sequential):
# 类注释:Conv2d_BN 类结合了卷积和批量归一化操作,用于训练时的动态权重调整,
# 并提供了一个静态权重部署的方法。
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
groups=1, bn_weight_init=1, resolution=-10000):
# 初始化方法,创建一个卷积层和批量归一化层的序列
super().__init__() # 调用基类的初始化方法
# 添加一个2D卷积层到序列中
self.add_module('c', nn.Conv2d(
in_channels=a, # 输入通道数
out_channels=b, # 输出通道数
kernel_size=ks, # 卷积核大小
stride=stride, # 步长
padding=pad, # 补零填充
dilation=dilation, # 扩张率
groups=groups, # 组数,用于分组卷积
bias=False)) # 没有偏置项,因为批量归一化层会学习偏置
# 添加一个2D批量归一化层到序列中
self.add_module('bn', nn.BatchNorm2d(num_features=b))
# 初始化批量归一化的权重为1,偏置为0
nn.init.constant_(self.bn.weight, bn_weight_init)
nn.init.constant_(self.bn.bias, 0)
@torch.no_grad() # 装饰器表示该方法内部不需要梯度计算
def switch_to_deploy(self):
# 转换方法,将训练时的动态权重转换为静态权重,便于模型部署
c, bn = self._modules.values() # 获取卷积层和批量归一化层的引用
# 计算批量归一化的权重缩放因子
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
# 将缩放因子应用于卷积层权重
w = c.weight * w[:, None, None, None]
# 计算新的偏置项,考虑了批量归一化的偏置和均值
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
# 创建一个新的静态卷积层,融合了卷积和批量归一化的权重和偏置
m = nn.Conv2d(in_channels=w.size(1) * self.c.groups,
out_channels=w.size(0),
kernel_size=w.shape[2:],
stride=self.c.stride,
padding=self.c.padding,
dilation=self.c.dilation,
groups=self.c.groups)
m.weight.data.copy_(w) # 复制融合后的权重到新卷积层
m.bias.data.copy_(b) # 复制融合后的偏置到新卷积层
return m # 返回融合后的静态卷积层
# 测试代码
if __name__ == '__main__':
# 创建一个Conv2d_BN模块实例,例如:输入通道16,输出通道24,卷积核大小3x3,步长1,填充1
conv2dbn = Conv2d_BN(16, 24, ks=3, stride=1, pad=1)
# 打印模块
print(conv2dbn)
# 转换为部署模式的静态权重卷积层
static_conv = conv2dbn.switch_to_deploy()
# 打印转换后的静态卷积层
print(static_conv)
功能解释:
Conv2d_BN
类在初始化时创建了一个没有偏置的卷积层,因为批量归一化层会学习到偏置项。批量归一化层的权重和偏置分别被初始化为 1 和 0。switch_to_deploy
方法在不需要梯度的情况下运行,它将卷积层的权重和批量归一化层的参数融合,生成一个新的静态卷积层,其中融合了归一化的权重和偏置,便于模型部署。- 融合的过程涉及到使用批量归一化的均值和方差来调整卷积层的权重和偏置,从而创建一个新的等效卷积层,该层在推理时可以不依赖于批量归一化层的动态调整。
import torch
from torch.nn import Module, ModuleList, Sequential, ReLU, Conv2d, BatchNorm2d
import itertools
class CascadedGroupAttention(Module):
'''
CascadedGroupAttention 类实现了级联群注意力机制,用于增强特征多样性,并逐步精化特征表示。
'''
def __init__(self, dim, num_heads=4,
attn_ratio=4,
resolution=7,
kernels=[5, 5, 5, 5], ):
super().__init__() # 调用基类的构造函数
# 计算每个头的键(key)维度,这里是输入通道维度除以16
key_dim = dim // 16
self.num_heads = num_heads # 注意头的数量
self.scale = key_dim ** -0.5 # 缩放因子,用于调整注意力分数
self.key_dim = key_dim # 每个头的键(key)维度
self.d = int(attn_ratio * key_dim) # 值(value)维度,是键维度的attn_ratio倍
self.attn_ratio = attn_ratio # 值维度与查询维度的比例
# 初始化查询、键、值和深度卷积模块的列表
qkvs = []
dws = []
for i in range(num_heads): # 对于每个注意力头
# 创建一个Conv2d_BN模块,用于生成每个头的查询、键和值
qkvs.append(Conv2d_BN(dim // num_heads, self.key_dim * 2 + self.d, resolution=resolution))
# 创建一个Conv2d_BN模块,用于在每个头中应用深度卷积
dws.append(Conv2d_BN(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim, resolution=resolution))
self.qkvs = ModuleList(qkvs) # 将查询、键、值模块封装成ModuleList
self.dws = ModuleList(dws) # 将深度卷积模块封装成ModuleList
# 定义输出的投影层,用于将注意力输出映射回原始维度
self.proj = Sequential(
ReLU(), # 激活函数
Conv2d_BN(self.d * num_heads, dim, bn_weight_init=0, resolution=resolution) # 投影卷积和批量归一化
)
# 构建注意力偏置索引,用于在每个头中添加不同的偏置
points = list(itertools.product(range(resolution), range(resolution))) # 产生窗口内所有点的组合
N = len(points) # 点的总数
attention_offsets = {} # 存储偏置的字典
idxs = [] # 存储偏置索引的列表
for p1 in points: # 对于窗口内每对点
for p2 in points: # 计算每对点之间的偏移
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets: # 如果偏移未记录,则添加到字典
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset]) # 将偏置索引添加到列表
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) # 将偏置转换为模型参数
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) # 注册偏置索引为常驻缓冲区
# 测试代码可以在这里添加,用于创建类的实例并测试其功能。
CascadedGroupAttention
的类,它实现了一种新颖的注意力机制,称为级联群注意力(Cascaded Group Attention)
功能解释:
CascadedGroupAttention
类通过在每个注意力头上应用不同的深度卷积和不同的输入分割来增强特征的多样性。- 在初始化过程中,它创建了查询、键、值生成模块(
qkvs
)和深度卷积模块(dws
),每个模块都包装在ModuleList
中。 - 注意力偏置是通过计算窗口内每对点之间的偏移来构建的,这些偏置被用作模型的参数,以便在训练过程中进行学习。
- 类中定义的
proj
属性是一个包含ReLU激活函数和投影卷积层的序列,用于将注意力机制的输出映射回原始的输入通道维度。 register_buffer
方法用于注册一个不会被模型参数梯度更新的缓冲区,这里用于存储偏置索引。
@torch.no_grad() # 装饰器,表示这个函数执行时不计算梯度
def train(self, mode=True):
# 训练模式切换,如果mode为True,则切换到训练模式,否则切换到评估模式
super().train(mode) # 调用基类的train方法
if mode and hasattr(self, 'ab'): # 如果处于训练模式且之前已经计算过静态权重,则删除静态权重
del self.ab
else: # 否则,计算静态权重并保存
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): # 前向传播方法,输入x的尺寸为(B,C,H,W),分别代表批次大小、通道数、高度和宽度
B, C, H, W = x.shape # 解构输入张量x的尺寸
trainingab = self.attention_biases[:, self.attention_bias_idxs] # 如果处于训练模式,使用动态计算的注意力偏置
feats_in = x.chunk(len(self.qkvs), dim=1) # 将输入特征沿通道方向分割,每部分输入到一个注意力头
feats_out = [] # 初始化输出特征列表
feat = feats_in[0] # 获取第一个头的输入特征
for i, qkv in enumerate(self.qkvs): # 遍历每个注意力头
if i > 0: # 如果不是第一个头,将前一个头的输出添加到当前头的输入中
feat = feat + feats_in[i]
feat = qkv(feat) # 通过查询、键、值生成模块处理特征
# 分离查询、键、值,其中每个头的查询、键、值通过卷积层生成后分割出来
q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1)
q = self.dws[i](q) # 对查询进行深度卷积处理
q, k, v = q.flatten(2), k.flatten(2), v.flatten(2) # 展平查询、键、值,从(B, C/h, H, W)变为(B, C/h, N)
# 计算注意力分数并应用偏置,其中N是展平后的长度
attn = ((q.transpose(-2, -1) @ k) * self.scale +
(trainingab[i] if self.training else self.ab[i]))
attn = attn.softmax(dim=-1) # 对注意力分数应用softmax归一化
# 计算加权的值,通过矩阵乘法更新特征表示,并重塑回原始空间维度
feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W)
feats_out.append(feat) # 将当前头的输出特征添加到输出列表
# 将所有头的输出特征在通道方向上进行拼接,并通过投影层
x = self.proj(torch.cat(feats_out, 1))
return x # 返回最终的输出特征
功能解释:
train
方法用于根据传入的mode
参数切换模型的训练或评估模式。在训练模式下,动态计算注意力偏置;在评估模式下,使用预先计算好的静态权重。forward
方法定义了模型的前向传播过程。输入特征x
被分割并输入到不同的注意力头中。每个头计算自注意力,并可能将前一个头的输出添加到当前头的输入中(如果是第一个头之外的头)。然后,每个头的输出被级联起来,并通过一个线性层投影回原始的输入维度。
注意:在代码中,self.attention_biases[:, self.attention_bias_idxs]
用于根据窗口内点的相对位置生成注意力偏置,这些偏置在训练和评估模式下可能不同。在训练模式下,可能使用动态计算的偏置,而在评估模式下,则使用预先计算并存储的静态偏置 self.ab
。