【2024工业3D异常检测文献】MMRD: 基于多模态逆向蒸馏学生-教师网络的异常检测方法

Rethinking Reverse Distillation for Multi-Modal Anomaly Detection

1、Background

异常检测方法主要分为基于嵌入的方法、合成和重建。基于嵌入的方法描述了提取特征的相应分布,异常通过测量测试图像的特征与估计分布之间的距离来检测。基于合成的方法估计异常自由样本和合成异常数据之间的决策边界进行检测。相反,基于重建的方法要么恢复输入,要么恢复中间级特征。

提出了一种新的 多模态逆向蒸馏(MMRD)范式 ,包括一个被冻结的多模态教师编码器来生成蒸馏目标,以及一个可学习的多模态学生解码器,目标是从教师那里恢复多模态表示。

具体来说,教师通过孪生结构从不同的模态中提取互补的视觉特征,然后无参数地融合这些来自不同层次的信息作为蒸馏的目标。对于学生来说,它从教师对正常训练数据的表示中学习模态相关的先验知识,并在它们之间进行交互,以形成目标重建的多模态表示。

supplement

知识蒸馏(Knowledge Distillation,KD)是一种模型压缩技术,它允许一个小型的学生模型通过模仿一个大型的教师模型来学习,目的是让学生模型达到与教师模型相似的性能。这个过程涉及到从教师模型中提取知识,然后将其传递给学生模型。知识蒸馏通常用于优化模型的大小和速度,同时尽量保持其准确性。

逆向蒸馏(Reverse Knowledge Distillation,RKD)是知识蒸馏的一个变种。逆向蒸馏不是将知识从大模型传递到小模型,而是 将知识从已经训练好的小模型传递回大模型 。这样做的目的是为了改进大模型的性能,尤其是在处理小模型表现更好的情况下。

KD & RKD 的区别:

  • 目标不同:
    • 知识蒸馏的目标是创建一个小型、高效的模型;
    • 逆向蒸馏的目标是改进大型模型的性能。
  • 知识流向不同:
    • 在知识蒸馏中,知识从教师模型流向学生模型;
    • 在逆向蒸馏中,知识从小模型流向大模型。
  • 应用场景不同:
    • 知识蒸馏通常用于模型压缩和加速,适用于资源受限的环境;
    • 逆向蒸馏则用于提升大模型在特定任务上的性能。
  • 计算方法:
    • 知识蒸馏:通常涉及到软目标的概念,即教师模型的输出概率分布,学生模型在训练时会尝试模仿这些软目标。
    • 逆向蒸馏:通过小模型的输出被用作引导大模型训练的软目标,大模型通过学习小模型的决策来改进自己的性能。

2、Method

主要组件:

  • 冻结的多模态教师编码器(Frozen Multi-Modal Teacher Encoder)
    • 这个编码器用于从不同的输入模态(如RGB图像和辅助模态如深度图)提取特征。
    • 它采用孪生网络结构来处理每种模态,并通过共享的卷积层和独立的批量归一化(BN)层提取特征。
    • 教师编码器的参数在训练过程中是固定的(即“冻结的”),意味着它们不会更新。
  • 无参数模态调制(Parameter-free Modality Modulation)
    • 此模块用于融合来自不同模态的特征,生成用于蒸馏的目标特征。
    • 融合过程不涉及任何可学习的参数,确保了蒸馏目标的稳定性。
  • 可学习的多模态学生解码器(Learnable Multi-Modal Student Decoder)
    • 学生解码器的目标是从教师编码器提供的目标特征中恢复多模态表示。
    • 它通过学习从正常训练数据中提取的模态相关先验知识来实现这一目标。

在这里插入图片描述

工作流程:

  • 特征提取
    • 使用孪生网络结构分别从RGB图像和辅助模态提取特征。
    • 这些特征被送入无参数模态调制模块,以生成蒸馏目标。
  • 模态相关先验的生成
    • 学生解码器通过学习从教师网络的正常数据表示中提取的“原型”来生成模态相关的先验知识。
  • 多模态先验的生成
    • 学生解码器通过内部模态交互和跨模态交互生成更精细的多模态表示,以帮助恢复目标特征。
  • 重建和异常检测
    • 学生解码器的输出与教师编码器的目标特征进行比较,计算重构误差。
    • 在推理阶段,通过计算像素级特征之间的相似性来生成异常图,从而实现异常检测和定位。

pseudo-code

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个孪生网络结构,用于多模态特征提取
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.shared_conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=64)
        self.bn2 = nn.BatchNorm2d(num_features=64)

    def forward(self, rgb_input, depth_input):
        # 提取RGB特征
        rgb_features = self.bn1(self.shared_conv(rgb_input))
        # 提取深度特征
        depth_features = self.bn2(self.shared_conv(depth_input))
        return rgb_features, depth_features

# 定义无参数模态调制模块
class ModalityModulation(nn.Module):
    def __init__(self):
        super(ModalalityModulation, self).__init__()

    def forward(self, rgb_features, depth_features):
        # 计算融合权重
        alpha = torch.sigmoid((rgb_features - depth_features).pow(2).mean(dim=1, keepdim=True))
        # 融合特征
        fused_features = rgb_features + alpha * depth_features
        return fused_features

# 定义学生网络
class StudentNetwork(nn.Module):
    def __init__(self):
        super(StudentNetwork, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(64 * 64 * 64, 128),
            nn.ReLU(),
            nn.Linear(128, 64 * 64 * 64),
            nn.Sigmoid()
        )

    def forward(self, fused_features):
        # 重建特征
        reconstructed_features = self.decoder(fused_features.view(fused_features.size(0), -1))
        return reconstructed_features

# 初始化网络
teacher_encoder = SiameseNetwork()
modality_fusion = ModalityModulation()
student_decoder = StudentNetwork()

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(student_decoder.parameters(), lr=0.001)

# 训练过程
for epoch in range(num_epochs):
    # 获取输入数据
    rgb_input, depth_input = load_data()
    teacher_rgb_features, teacher_depth_features = teacher_encoder(rgb_input, depth_input)
    fused_features = modality_fusion(teacher_rgb_features, teacher_depth_features)
    
    # 学生网络重建
    student_output = student_decoder(fused_features)
    
    # 计算损失
    loss = criterion(student_output, fused_features)
    
    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# 推理和异常检测
def detect_anomalies(rgb_input, depth_input):
    teacher_rgb_features, teacher_depth_features = teacher_encoder(rgb_input, depth_input)
    fused_features = modality_fusion(teacher_rgb_features, teacher_depth_features)
    student_reconstructed_features = student_decoder(fused_features)
    
    # 计算异常图
    anomaly_map = torch.abs(student_reconstructed_features - fused_features)
    return anomaly_map

# 使用模型进行异常检测
anomaly_map = detect_anomalies(rgb_input, depth_input)

3、Experiments

🐂🐎。。。

在这里插入图片描述

4、Conclusion

  • 提出了一种用于异常检测的多模态逆向蒸馏(MMRD)范式;
  • 使用一个被冻结的多模态教师编码器来生成多模态蒸馏目标,供可学习的多模态学生解码器恢复。
爬虫Python学习是指学习如何使用Python编程语言来进行网络爬取和数据提取的过程。Python是一种简单易学且功能强大的编程语言,因此被广泛用于爬虫开发。爬虫是指通过编写程序自动抓取网页上的信息,可以用于数据采集、数据分析、网站监测等多个领域。 对于想要学习爬虫的新手来说,Python是一个很好的入门语言。Python的语法简洁易懂,而且有丰富的第三方库和工具,如BeautifulSoup、Scrapy等,可以帮助开发者更轻松地进行网页解析和数据提取。此外,Python还有很多优秀的教程和学习资源可供选择,可以帮助新手快速入门并掌握爬虫技能。 如果你对Python编程有一定的基础,那么学习爬虫并不难。你可以通过观看教学视频、阅读教程、参与在线课程等方式来学习。网络上有很多免费和付费的学习资源可供选择,你可以根据自己的需求和学习风格选择适合自己的学习材料。 总之,学习爬虫Python需要一定的编程基础,但并不难。通过选择合适的学习资源和不断实践,你可以逐步掌握爬虫的技能,并在实际项目中应用它们。 #### 引用[.reference_title] - *1* *3* [如何自学Python爬虫? 零基础入门教程](https://blog.csdn.net/zihong523/article/details/122001612)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [新手小白必看 Python爬虫学习路线全面指导](https://blog.csdn.net/weixin_67991858/article/details/128370135)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值