Rethinking Reverse Distillation for Multi-Modal Anomaly Detection
1、Background
异常检测方法主要分为基于嵌入的方法、合成和重建。基于嵌入的方法描述了提取特征的相应分布,异常通过测量测试图像的特征与估计分布之间的距离来检测。基于合成的方法估计异常自由样本和合成异常数据之间的决策边界进行检测。相反,基于重建的方法要么恢复输入,要么恢复中间级特征。
提出了一种新的 多模态逆向蒸馏(MMRD)范式
,包括一个被冻结的多模态教师编码器来生成蒸馏目标,以及一个可学习的多模态学生解码器,目标是从教师那里恢复多模态表示。
具体来说,教师通过孪生结构从不同的模态中提取互补的视觉特征,然后无参数地融合这些来自不同层次的信息作为蒸馏的目标。对于学生来说,它从教师对正常训练数据的表示中学习模态相关的先验知识,并在它们之间进行交互,以形成目标重建的多模态表示。
supplement
知识蒸馏(Knowledge Distillation,KD)是一种模型压缩技术,它允许一个小型的学生模型通过模仿一个大型的教师模型来学习,目的是让学生模型达到与教师模型相似的性能。这个过程涉及到从教师模型中提取知识,然后将其传递给学生模型。知识蒸馏通常用于优化模型的大小和速度,同时尽量保持其准确性。
逆向蒸馏(Reverse Knowledge Distillation,RKD)是知识蒸馏的一个变种。逆向蒸馏不是将知识从大模型传递到小模型,而是 将知识从已经训练好的小模型传递回大模型
。这样做的目的是为了改进大模型的性能,尤其是在处理小模型表现更好的情况下。
KD & RKD 的区别:
- 目标不同:
- 知识蒸馏的目标是创建一个小型、高效的模型;
- 逆向蒸馏的目标是改进大型模型的性能。
- 知识流向不同:
- 在知识蒸馏中,知识从教师模型流向学生模型;
- 在逆向蒸馏中,知识从小模型流向大模型。
- 应用场景不同:
- 知识蒸馏通常用于模型压缩和加速,适用于资源受限的环境;
- 逆向蒸馏则用于提升大模型在特定任务上的性能。
- 计算方法:
- 知识蒸馏:通常涉及到软目标的概念,即教师模型的输出概率分布,学生模型在训练时会尝试模仿这些软目标。
- 逆向蒸馏:通过小模型的输出被用作引导大模型训练的软目标,大模型通过学习小模型的决策来改进自己的性能。
2、Method
主要组件:
- 冻结的多模态教师编码器(Frozen Multi-Modal Teacher Encoder)
- 这个编码器用于从不同的输入模态(如RGB图像和辅助模态如深度图)提取特征。
- 它采用孪生网络结构来处理每种模态,并通过共享的卷积层和独立的批量归一化(BN)层提取特征。
- 教师编码器的参数在训练过程中是固定的(即“冻结的”),意味着它们不会更新。
- 无参数模态调制(Parameter-free Modality Modulation)
- 此模块用于融合来自不同模态的特征,生成用于蒸馏的目标特征。
- 融合过程不涉及任何可学习的参数,确保了蒸馏目标的稳定性。
- 可学习的多模态学生解码器(Learnable Multi-Modal Student Decoder)
- 学生解码器的目标是从教师编码器提供的目标特征中恢复多模态表示。
- 它通过学习从正常训练数据中提取的模态相关先验知识来实现这一目标。
工作流程:
- 特征提取
- 使用孪生网络结构分别从RGB图像和辅助模态提取特征。
- 这些特征被送入无参数模态调制模块,以生成蒸馏目标。
- 模态相关先验的生成
- 学生解码器通过学习从教师网络的正常数据表示中提取的“原型”来生成模态相关的先验知识。
- 多模态先验的生成
- 学生解码器通过内部模态交互和跨模态交互生成更精细的多模态表示,以帮助恢复目标特征。
- 重建和异常检测
- 学生解码器的输出与教师编码器的目标特征进行比较,计算重构误差。
- 在推理阶段,通过计算像素级特征之间的相似性来生成异常图,从而实现异常检测和定位。
pseudo-code
import torch
import torch.nn as nn
import torch.optim as optim
# 定义一个孪生网络结构,用于多模态特征提取
class SiameseNetwork(nn.Module):
def __init__(self):
super(SiameseNetwork, self).__init__()
self.shared_conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(num_features=64)
self.bn2 = nn.BatchNorm2d(num_features=64)
def forward(self, rgb_input, depth_input):
# 提取RGB特征
rgb_features = self.bn1(self.shared_conv(rgb_input))
# 提取深度特征
depth_features = self.bn2(self.shared_conv(depth_input))
return rgb_features, depth_features
# 定义无参数模态调制模块
class ModalityModulation(nn.Module):
def __init__(self):
super(ModalalityModulation, self).__init__()
def forward(self, rgb_features, depth_features):
# 计算融合权重
alpha = torch.sigmoid((rgb_features - depth_features).pow(2).mean(dim=1, keepdim=True))
# 融合特征
fused_features = rgb_features + alpha * depth_features
return fused_features
# 定义学生网络
class StudentNetwork(nn.Module):
def __init__(self):
super(StudentNetwork, self).__init__()
self.decoder = nn.Sequential(
nn.Linear(64 * 64 * 64, 128),
nn.ReLU(),
nn.Linear(128, 64 * 64 * 64),
nn.Sigmoid()
)
def forward(self, fused_features):
# 重建特征
reconstructed_features = self.decoder(fused_features.view(fused_features.size(0), -1))
return reconstructed_features
# 初始化网络
teacher_encoder = SiameseNetwork()
modality_fusion = ModalityModulation()
student_decoder = StudentNetwork()
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(student_decoder.parameters(), lr=0.001)
# 训练过程
for epoch in range(num_epochs):
# 获取输入数据
rgb_input, depth_input = load_data()
teacher_rgb_features, teacher_depth_features = teacher_encoder(rgb_input, depth_input)
fused_features = modality_fusion(teacher_rgb_features, teacher_depth_features)
# 学生网络重建
student_output = student_decoder(fused_features)
# 计算损失
loss = criterion(student_output, fused_features)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
# 推理和异常检测
def detect_anomalies(rgb_input, depth_input):
teacher_rgb_features, teacher_depth_features = teacher_encoder(rgb_input, depth_input)
fused_features = modality_fusion(teacher_rgb_features, teacher_depth_features)
student_reconstructed_features = student_decoder(fused_features)
# 计算异常图
anomaly_map = torch.abs(student_reconstructed_features - fused_features)
return anomaly_map
# 使用模型进行异常检测
anomaly_map = detect_anomalies(rgb_input, depth_input)
3、Experiments
🐂🐎。。。
4、Conclusion
- 提出了一种用于异常检测的多模态逆向蒸馏(MMRD)范式;
- 使用一个被冻结的多模态教师编码器来生成多模态蒸馏目标,供可学习的多模态学生解码器恢复。