代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# 自定义注意力机制类
class Attention(nn.Module):
def __init__(self, p=2):
super(Attention, self).__init__()
self.p = p # 特征图的幂次
def at(self, f):
# 计算每层特征图的注意力图
B, C, H, W = f.size()
attention_map = F.softmax(f.pow(self.p).mean(1).view(B, -1), dim=1).view(B, H, W)
print(f"Attention map size: {attention_map.size()}\n{attention_map}") # 打印注意力图的维度和内容
return attention_map
def at_loss(self, f_s, f_t):
# 计算浅层特征图(f_s)与深层特征图(f_t)的注意力损失
loss = (self.at(f_s) - self.at(f_t)).pow(2).mean()
print(f"Attention loss between layers: {loss.item()}") # 打印两层之间的注意力损失
return loss
def forward(self, features):
# features为模型输出的多个特征图
loss = 0
for i in range(len(features) - 1):
print(f"Layer {i+1} feature map size: {features[i].size()}")
print(f"Layer {i+2} feature map size: {features[i+1].size()}")
loss += self.at_loss(features[i], F.interpolate(features[i+1], scale_factor=2., mode="bilinear"))
return loss
# 自定义简单的CNN模型,输出多个层的特征图
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2, 2) # 池化层,减小特征图尺寸
def forward(self, x):
f1 = self.pool(F.relu(self.conv1(x))) # 第一层特征图
f2 = self.pool(F.relu(self.conv2(f1))) # 第二层特征图
f3 = self.pool(F.relu(self.conv3(f2))) # 第三层特征图
return [f1, f2, f3]
# 定义损失和优化器
criterion_cls = nn.CrossEntropyLoss() # 分类损失(此处只是模拟)
attention_module = Attention() # 注意力蒸馏模块
# 初始化模型、输入数据和标签
model = SimpleCNN()
inputs = torch.randn(1, 3, 64, 64) # 模拟一个 64x64 的输入图片
targets = torch.tensor([0]) # 模拟分类目标
# 前向传播,获取各层的特征图
features = model(inputs)
print("\nFeature maps generated by each layer:")
for i, f in enumerate(features):
print(f"Feature map from layer {i+1}: shape = {f.shape}")
# 计算SAD的注意力损失
print("\nCalculating attention distillation loss:")
loss_div = attention_module(features)
# 计算分类损失(模拟分类任务)
logits = torch.randn(1, 10) # 假设有10个分类
loss_cls = criterion_cls(logits, targets)
# 总损失 = 分类损失 + 注意力蒸馏损失
total_loss = loss_cls + loss_div
print(f"\nTotal loss (classification + attention distillation): {total_loss.item()}")
详细解释
注意力机制
1. 平方操作的作用:
当我们对特征图中的每个像素值进行平方时:
- 高激活值(>1):会被进一步放大,比如一个值为2的像素点,平方后变为4。这个过程将模型认为重要的激活区域显著增强,提升它在注意力图中的相对权重。
- 低激活值(<1):平方后变得更小,比如0.5平方后变为0.25。这个过程会进一步减弱那些激活较低的区域,使它们在最终的注意力图中被“忽略”。
attention_map = F.softmax(f.pow(self.p).mean(1).view(B, -1), dim=1).view(B, H, W)
-
f.pow(self.p)
:将特征图f
中的每个值进行p
次方操作,默认情况下p=2
,也就是每个像素值都进行平方。 -
f.mean(1)
:对特征图的通道维度(即第2个维度)进行平均,意味着我们将多个通道的信息压缩成一个单通道的二维特征图。这样可以将所有通道的信息进行综合,使网络对整个图像有一个全局的感知。- 假设
f
的维度是(B, C, H, W)
,其中:B
是批次大小(batch size),C
是通道数(例如在RGB图像中C=3
),H
和W
是特征图的高度和宽度。
mean(1)
操作后,特征图的通道数被消除,结果是一个形状为(B, H, W)
的张量。
- 假设
-
view(B, -1)
:- 这里的
view(B, -1)
是将特征图转换为(B, N)
的形状,其中N
表示特征图中的每个位置(像素)的总数(即H * W
)。 -1
表示自动推断这个维度的大小,它会根据总的元素个数和其他已知的维度(B
和H*W
)进行推断。例如,如果B=32
且特征图大小是32x32
,那么view(B, -1)
会把它转换成(32, 1024)
的形状。
这样做的目的是将二维特征图展平成一个一维向量,使得我们可以在整个特征图上应用
softmax
。 - 这里的
-
F.softmax(..., dim=1)
:softmax
操作将展平后的特征图中的值归一化,使得它们形成一个概率分布。这里dim=1
表示对每个样本的所有像素点(即H*W
个像素)进行softmax
,确保每个样本的特征图中的注意力值和为1。
-
view(B, H, W)
:- 经过
softmax
之后,特征图的形状是(B, H*W)
,也就是展平后的二维特征图。 - 使用
view(B, H, W)
将展平的一维特征图重新变回原来的二维形状(B, H, W)
,这是为了继续后续的特征图处理,保持特征图的空间结构。
- 经过
简单例子:
假设我们有一个批次大小为 B=2
,特征图大小为 H=4, W=4
,也就是说特征图的初始形状为 (2, C, 4, 4)
。
- 经过
mean(1)
操作,特征图的通道维度C
被消除,结果变为(2, 4, 4)
。 - 然后
view(B, -1)
将其展平为(2, 16)
,也就是每个样本有 16 个像素点。 - 通过
softmax
操作对这 16 个像素点进行归一化,确保它们的注意力值和为1。 - 最后
view(B, H, W)
将展平的一维张量恢复为(2, 4, 4)
的二维形状。
总结:
view(B, -1)
是为了将二维特征图展平为一维,以便在整个特征图上应用softmax
,计算每个位置的注意力值。view(B, H, W)
是为了在softmax
之后将展平的特征图重新恢复为原始的二维形状,以便后续操作继续使用原始的空间结构。