Incomplete Multimodal Industrial Anomaly Detection via Cross-Modal Distillation
1、Background
近年来,基于3D点云和RGB图像的多模态工业异常检测(IAD)研究强调了利用模态间的冗余性和互补性对于精确分类和分割的重要性。
在项目中,提出了CMDIAD方法,一个用于IAD的跨模态蒸馏框架,以证明多模态训练、少模态推理(Multi-modal Training, Few-modal Inference,MTFI)流程的可行性。
与典型的知识蒸馏(KD)方法不同,传统的方法涉及接受相同模态输入的教师和学生网络,而跨模态KD的学生网络试图基于额外的模态生成与教师相同的输出。
换句话说,跨模态KD能够通过建立一种模态与另一种模态之间的映射来估计缺失的信息。
通过跨模态KD来增强无监督IAD方法,开发一个多模态训练、少模态推理(MTFI)流程,使模型能够用多种检查方法的数据进行训练,但只使用其中一种方法进行推理。
CMDIAD基于IAD的跨模态蒸馏框架,学习生成缺失模态的跨模态幻觉,目标是调查MTFI流程的可行性以及可以跨模态传输的信息,并采用了基于无监督记忆库方法。
CMDIAD框架的 核心思想是训练可学习的跨模态KD网络,以模拟在训练阶段仅出现的给定模态的特权信息,如特征或输入,然后尝试在推理阶段生成其缺失的信息。
Supplement
跨模态KD网络(Cross-Modal Knowledge Distillation,简称KD)是指一种特殊的知识蒸馏方法,它用于处理多模态学习中缺失模态的情况。在多模态学习中,数据通常包含来自不同源(如视觉、声音、文本等)的多种类型的信息。理想情况下,模型训练和推理都希望能够利用所有这些模态的信息。然而,在现实世界的应用中,某些模态的数据可能不可用或不完整。
跨模态KD网络的核心思想是 利用一种模态的已知信息来估计或生成另一种模态的缺失信息。
这样,即使在缺少某些模态数据的情况下,模型也能够进行有效的推理。具体来说,跨模态KD网络通过以下步骤实现:
- 教师网络和学生网络:在传统的知识蒸馏中,教师网络通常是预训练的深度神经网络,它能够处理完整的模态数据;学生网络则尝试学习教师网络的知识,但学生网络可能只接触到部分模态的数据。
- 模态融合:在跨模态KD中,学生网络被训练来预测或生成那些在推理时可能不可用的模态的信息。这通常涉及到从一种模态(如RGB图像)生成另一种模态(如3D点云)的特征表示。
- 蒸馏:通过训练,学生网络学习到了如何从一种模态映射到另一种模态。这种映射可以是直接的(如特征到特征的映射),也可以是间接的(如输入到输入的映射)。
- 损失函数:在训练过程中,通常会定义一个损失函数来衡量学生网络生成的模态信息与真实模态信息之间的差异。通过最小化这个损失,学生网络能够学习到如何更好地估计或生成缺失的模态信息。
- 推理:在推理阶段,即使某些模态的数据不可用,学生网络也能够利用它在训练阶段学到的知识来生成缺失模态的估计,从而进行有效的异常检测。
2、Method
核心思想:
- 多模态训练,少模态推理(MTFI):模型在训练时使用来自多个检查方法的数据(多模态),但在推理时仅使用其中一种模态的数据。这样可以在实际应用中节省成本和时间,因为不是所有样本都需要通过所有检查方法进行评估。
- 跨模态知识蒸馏(KD):通过跨模态KD,模型能够利用一种模态的信息来估计或生成另一种缺失模态的信息。例如,使用RGB图像信息来生成对应的3D点云信息。
- 记忆库方法:使用记忆库来存储正常样本的特征,并在推理时比较测试样本与这些正常特征的差异,以识别异常。
MTFI流程(只使用PCs进行推理):
- 多模态训练阶段:
- 在这个阶段,模型使用来自多个模态(例如,RGB图像和点云)的数据进行训练。
- 利用跨模态知识蒸馏技术,模型学习如何从一种模态(如RGB图像)生成另一种模态(如点云)的特征表示。
- 通过这种方式,模型能够理解不同模态之间的关系,并学习如何利用一种模态的信息来补充另一种模态的缺失信息。
- 少模态推理阶段:
- 在这个阶段,模型仅使用一种模态的数据(在Figure 2的情况下,只使用点云数据)来进行推理。
- 即使在训练时使用了多模态数据,模型也能够在推理时只使用点云数据来检测异常。
- 通过在训练阶段学习的跨模态映射,模型能够生成缺失模态(如RGB图像)的幻觉,并利用这些幻觉来进行有效的异常检测。
直观解释:
- 训练阶段:模型像“学生”一样学习如何从多种类型的数据中提取和生成有用的特征。
- 推理阶段:模型像“老师”一样应用学到的知识,即使在缺少某些数据类型的情况下也能做出决策。
pseudo-code
# point_cloud_feature_extractor 和 rgb_feature_extractor 是预训练的模型,用于从点云和RGB图像中提取特征。
# F2F_network、F2I_network 和 I2F_network 是跨模态蒸馏网络,用于不同模态间的特征转换。
# select_coreset 是一个函数,用于从正常样本特征中选择一个代表性的子集来构建记忆库。
# compute_anomaly_score 是一个函数,用于计算给定测试样本相对于记忆库的异常分数。
# fuse_scores 是一个函数,用于融合不同模态的异常分数,以得到一个综合的异常检测结果。
# threshold 是一个阈值,用于判断测试样本是否异常。
# 步骤1: 特征提取
def feature_extraction(PCs, RGB_image):
FP_C = point_cloud_feature_extractor(PCs) # 使用Point-MAE
FRGB = rgb_feature_extractor(RGB_image) # 使用DINO
return FP_C, FRGB
# 步骤2: 跨模态蒸馏
def cross_modal_distillation(FP_C, FRGB):
# 特征到特征蒸馏
H_fRGB = F2F_network(FP_C)
H_FP_C = F2F_network(FRGB)
# 特征到输入蒸馏
HiRGB = F2I_network(FP_C)
HfRGB = RGB_feature_extractor(HiRGB)
# 输入到特征蒸馏
HfRGB = I2F_network(PCs)
return H_fRGB, H_FP_C, HfRGB
# 步骤3: 记忆库构建
def build_memory_bank(FP_C, FRGB, normal_samples):
M_P_C = select_coreset(FP_C, normal_samples)
M_RGB = select_coreset(FRGB, normal_samples)
return M_P_C, M_RGB
# 步骤4: 推理与异常检测
def anomaly_detection(FP_C, M_P_C, H_fRGB, M_RGB):
# 计算异常分数
anomaly_score_PCs = compute_anomaly_score(FP_C, M_P_C)
anomaly_score_RGB = compute_anomaly_score(H_fRGB, M_RGB)
# 融合异常分数
final_score = fuse_scores(anomaly_score_PCs, anomaly_score_RGB)
# 阈值判断异常
return final_score > threshold
# 主流程
def main(PCs, RGB_image, normal_samples):
FP_C, FRGB = feature_extraction(PCs, RGB_image)
H_fRGB, H_FP_C, HfRGB = cross_modal_distillation(FP_C, FRGB)
M_P_C, M_RGB = build_memory_bank(FP_C, FRGB, normal_samples)
anomaly_detected = anomaly_detection(FP_C, M_P_C, H_fRGB, M_RGB)
return anomaly_detected
算法流程:
- 特征提取
- 使用预训练的深度学习模型(如DINO和Point-MAE)分别对RGB图像和3D点云进行特征提取。
- 跨模态蒸馏
- 特征到特征蒸馏(F2F):直接从一种模态的特征图生成另一种模态的特征图。
- 特征到输入蒸馏(F2I):从一种模态的特征图中生成另一种模态的输入数据。
- 输入到特征蒸馏(I2F):从一种模态的输入数据生成另一种模态的特征图。
- 优化
- 通过最小化真实特征和生成特征之间的距离来训练跨模态蒸馏网络。
- 记忆库构建
- 使用核心集选择方法来构建记忆库,确保记忆库尽可能全面且不包含重复特征。
- 决策层融合
- 在推理时,使用单类支持向量机(OCSVM)比较测试样本的特征与记忆库中的特征,以计算异常分数。
- 推理
- 根据计算出的异常分数,模型对测试样本进行分类和分割,以识别异常。
3、Experiments
4、Conclusion
开发了一个基于跨模态蒸馏驱动的多模态工业异常检测(IAD)的有效框架。