input = torch.randn(1,64,32,32)
model = CascadedGroupAttention(dim=64,resolution=32) #resolution要求和图片大小一样
output = model(input)
print('input_size:', input.size())
print('output_size:', output.size())
#查看模型结构(每个层的结构)
print(model)
# 查看模型名称和大小
for name, param in model.named_parameters():
print(name, param.size())
# 查看模块的参数数量
param_count = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {param_count}")
输出:
'''
input_size: torch.Size([1, 64, 32, 32])
output_size: torch.Size([1, 64, 32, 32])
CascadedGroupAttention(
(qkvs): ModuleList(
(0-3): 4 x Conv2d_BN(
(c): Conv2d(16, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(dws): ModuleList(
(0-3): 4 x Conv2d_BN(
(c): Conv2d(4, 4, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=4, bias=False)
(bn): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(proj): Sequential(
(0): ReLU()
(1): Conv2d_BN(
(c): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
参数量的计算
attention_biases torch.Size([4, 1024])
注意力偏置 (
attention_biases
):
attention_biases
: 参数大小为torch.Size([4, 1024])
,意味着有 4 个头,每个头有 1024 个偏置值,总参数数为4 * 1024 = 4096
。
qkvs.0.c.weight torch.Size([24, 16, 1, 1])
qkvs.0.bn.weight torch.Size([24])
qkvs.0.bn.bias torch.Size([24])
qkvs.1.c.weight torch.Size([24, 16, 1, 1])
qkvs.1.bn.weight torch.Size([24])
qkvs.1.bn.bias torch.Size([24])
qkvs.2.c.weight torch.Size([24, 16, 1, 1])
qkvs.2.bn.weight torch.Size([24])
qkvs.2.bn.bias torch.Size([24])
qkvs.3.c.weight torch.Size([24, 16, 1, 1])
qkvs.3.bn.weight torch.Size([24])
qkvs.3.bn.bias torch.Size([24])
查询、键、值卷积层 (
qkvs
):
- 每个头的
qkvs
模块包含一个Conv2d_BN
,其中卷积层参数为torch.Size([24, 16, 1, 1])
,批量归一化层参数为torch.Size([24])
(权重和偏置)。- 因为有 4 个这样的头,所以这部分的总参数数为
4 * ((24 * 16 * 1 * 1) + 24 + 24)
。
dws.0.c.weight torch.Size([4, 1, 5, 5])
dws.0.bn.weight torch.Size([4])
dws.0.bn.bias torch.Size([4])
dws.1.c.weight torch.Size([4, 1, 5, 5])
dws.1.bn.weight torch.Size([4])
dws.1.bn.bias torch.Size([4])
dws.2.c.weight torch.Size([4, 1, 5, 5])
dws.2.bn.weight torch.Size([4])
dws.2.bn.bias torch.Size([4])
dws.3.c.weight torch.Size([4, 1, 5, 5])
dws.3.bn.weight torch.Size([4])
dws.3.bn.bias torch.Size([4])
深度卷积层 (
dws
):
- 每个深度卷积层的
Conv2d_BN
参数为torch.Size([4, 1, 5, 5])
,批量归一化层参数为torch.Size([4])
。- 同样有 4 个这样的层,总参数数为
4 * ((4 * 1 * 5 * 5) + 4 + 4)
。
proj.1.c.weight torch.Size([64, 64, 1, 1])
proj.1.bn.weight torch.Size([64])
proj.1.bn.bias torch.Size([64])
输出投影层 (
proj
):
- 包含一个ReLU激活函数(无参数),和一个
Conv2d_BN
,其参数为torch.Size([64, 64, 1, 1])
,批量归一化层参数为torch.Size([64])
。- 这部分的参数数量为
((64 * 64 * 1 * 1) + 64 + 64)
Total number of parameters: 10480
现在,我们来计算每个部分的参数数量:
attention_biases
:4 * 1024 = 4096
个参数qkvs
:4 * ((24 * 16 + 24) + 24) = 4 * (384 + 24 + 24) = 4 * 432 = 1728
个参数dws
:4 * ((4 * 1 * 5 * 5) + 4 + 4) = 4 * (100 + 8) = 4 * 108 = 432
个参数proj
:((64 * 64 + 64) + 64) = 4096 + 64 + 64 = 4224
个参数将这些加起来得到总的参数数量:
4096 + 1728 + 432 + 4224 = 10480
'''
qkvs.0.c.weight torch.Size([24, 16, 1, 1])为什么是24 16 1 1
在 CascadedGroupAttention
类中,qkvs
属性是一个由 Conv2d_BN
模块组成的 ModuleList
。每个 Conv2d_BN
模块都包含一个卷积层 (Conv2d
) 后面跟着一个批量归一化层 (BatchNorm2d
)。
对于 qkvs.0.c.weight torch.Size([24, 16, 1, 1])
,这个尺寸表示的是第一个 Conv2d_BN
模块中卷积层的权重张量的形状。我们可以逐个维度解释它的含义:
(qkvs): ModuleList(
(0-3): 4 x Conv2d_BN(
(c): Conv2d(16, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
-
24
:这表示卷积层的输出通道数,也就是它将产生 24 个特征图。在这个上下文中,这些特征图对应于一个注意力头中的查询(query)、键(key)和值(value)的组合。由于每个头的查询、键和值的总维度是self.key_dim * 2 + self.d
,并且这里产生了 24 个输出通道,这意味着每个通道的维度是(self.key_dim * 2 + self.d) / 24
。 -
16
:这表示卷积层的输入通道数。在这个模块中,输入通道数是dim // num_heads
,其中dim
是输入特征的总通道数,num_heads
是注意力头的数量。所以,如果dim
是 64,并且有 4 个头,那么每个头接收的输入通道数是64 / 4 = 16
。 -
1
:这两个1
表示卷积核的高和宽,这里是 1x1 卷积。1x1 卷积通常用于在不改变特征图的空间尺寸的情况下调整通道数,也就是说,它对输入特征图的每个像素进行逐通道的线性变换。
综上所述,qkvs.0.c.weight torch.Size([24, 16, 1, 1])
表示第一个注意力头的卷积层有 24 个输出通道和 16 个输入通道,卷积核的大小是 1x1。这种设计通常用于减少参数数量并允许网络在保持分辨率不变的情况下学习通道之间的复杂交互。
由于每个头的查询、键和值的总维度是
self.key_dim * 2 + self.d
,并且这里产生了 24 个输出通道,这意味着每个通道的维度是(self.key_dim * 2 + self.d) / 24
。
在 CascadedGroupAttention
类中,每个注意力头都有自己的查询(Q)、键(K)和值(V)。这些是通过一个卷积层生成的,卷积层的输出通道数必须能够被分割成 Q、K 和 V,它们的维度分别是:
- 查询(Q):
self.key_dim
- 键(K):
self.key_dim
- 值(V):
self.d
其中 self.key_dim
是每个头的键/查询维度,self.d
是值的维度,且 self.d = int(attn_ratio * key_dim)
。
现在,如果我们有一个卷积层,它的输出通道数是 24,这意味着所有头的 Q、K 和 V 的总和必须等于这个数。由于每个头都会生成 Q、K 和 V,所以如果 num_heads
是头的总数,那么卷积层的输出通道数应该是 (自我.key_dim * 2 + self.d) * num_heads
。
在给出的例子中,卷积层的输出通道数是 24,而不是直接对应于单个头的 Q、K、V 维度总和,这是因为这 24 个通道是分配给所有头的。具体来说,如果 num_heads
是 4,那么每个头的 Q、K、V 将共享这 24 个通道。但是,代码中并没有直接说明 num_heads
的值,而是通过卷积层的输出通道数来间接指定。
如果我们假设 self.key_dim * 2 + self.d
能够整除卷积层的输出通道数(24),那么每个头将获得的通道数是:
然后,每个头的 Q、K 和 V 将从这些通道中进一步分配。例如,如果 self.key_dim
是 4,self.d
是 16(假设 attn_ratio
是 4),那么每个头的 Q、K、V 将从 24 个通道中分配,可能的分配方式是:
- Q: 4 通道
- K: 4 通道
- V: 16 通道
但这种分配取决于 self.key_dim
和 self.d
的具体值,以及卷积层输出通道数如何与头的数量相匹配。在实际代码中,卷积层是为每个头单独定义的,所以每个头都会有自己的 Q、K、V,而不会共享通道。这意味着卷积层的输出通道数应该直接等于 self.key_dim * 2 + self.d
,而不是 24。这可能是之前解释中的混淆之处。每个头的卷积层应该有足够的输出通道来独立地生成 Q、K 和 V,而不需要共享通道。