基于异常检测的CFA算法
算法介绍
基于异常检测的CFA(Coupled-hypersphere-based Feature Adaptation)算法是一种用于目标导向异常定位的深度学习方法,主要通过特征自适应和迁移学习来提高异常检测的精度。以下是对其的详细介绍:
算法原理
CFA算法的核心思想是通过适应目标数据集的特征来实现精细的异常定位。它包括以下关键部分:
-
可学习的补丁描述符(Patch Descriptor):该模块用于学习和嵌入面向目标的特征。它将目标数据集的正常样本中的补丁特征进行学习,使其在记忆特征周围具有较高的密度。
-
耦合超球体(Coupled Hypersphere):CFA通过对比监督学习,以内存库中的记忆特征为中心创建叠加的超球体(即耦合超球体),使正常特征在这些超球体内密集分布。
-
可扩展内存库(Memory Bank):该内存库独立于目标数据集的大小,用于存储从正常样本中提取的初始目标导向特征。通过特征自适应,内存库中的特征能够更好地适应目标数据集。
算法流程
-
特征提取:使用预训练的CNN(如在ImageNet上预训练的模型)对目标数据集的样本进行推断,获取不同尺度的特征图。将这些特征图插值到相同的分辨率后进行连接,生成补丁特征。
-
特征自适应训练:通过定义基于耦合超球体的损失函数,训练补丁描述符,使正常特征在记忆特征周围密集分布。具体来说,通过最小化损失函数,将补丁特征嵌入到以记忆特征为中心的超球体内。
-
异常检测与定位:在测试阶段,将测试样本的补丁特征与内存库中的最近邻记忆特征进行匹配,生成表示异常程度的热图。最后,通过特定的评分函数计算热图中的异常定位评分图。
算法优势
-
特征自适应:CFA通过迁移学习和特征自适应,解决了预训练CNN的偏差问题,使模型能够更好地适应目标数据集。
-
高效的内存库:CFA的内存库与目标数据集的大小无关,显著降低了内存占用,同时保持了高性能。
-
优异的性能:在MVTec AD基准测试中,CFA在图像级异常检测和像素级异常定位上分别达到了99.5%和98.5%的AUROC评分。
应用场景
CFA算法适用于需要精确异常定位的工业场景,例如产品质量检测、缺陷定位等。其高效的特征自适应能力和低内存占用使其在实际应用中具有显著优势。
环境配置
-
Python:推荐使用Python 3.8或更高版本。
-
深度学习框架:CFA算法通常基于PyTorch实现,因此需要安装PyTorch。根据你的CUDA版本选择合适的PyTorch版本。
-
pytorch官网PyTorch
-
其他依赖库:安装以下常用库:
pip install numpy matplotlib opencv-python scikit-image
数据集格式
因为该算法是异常检测,只需要正常数据集和异常数据集以及异常类别即可,不需要标注工作。https://www.mvtec.com/company/research/datasets/mvtec-ad(这个数据集为MVTec AD(Anomaly Detection))该数据集包含多种工业产品的正常样本和异常样本,适用于异常检测任务。如下图为数据集格式
代码准备
如果有现成的开源代码,可以直接从GitHub等平台下载。例如,搜索CFA算法的开源实现。
如果没有现成的代码,可以参考CFA算法的论文,自行实现。
https://arxiv.org/abs/2206.04325 论文链接
建议可以把这个论文下载,给通义或者其他大模型,让它按照传统的模型搭建流程进行搭建自己的网络结构。这个可以帮助我们快速复现
关键模块实现
以下是CFA算法的关键模块实现思路:
特征提取模块
基于预训练的CNN(如ResNet)提取特征:
import torch
import torchvision.models as models
class FeatureExtractor(torch.nn.Module):
def __init__(self):
super(FeatureExtractor, self).__init__()
self.cnn = models.resnet18(pretrained=True)
self.cnn = torch.nn.Sequential(*list(self.cnn.children())[:-2]) # 去掉最后两层
def forward(self, x):
return self.cnn(x)
特征自适应模块
实现耦合超球体特征自适应:
import torch.nn.functional as F
def coupled_hypersphere_loss(features, memory_bank):
# 计算特征与内存库中记忆特征的距离
distances = torch.cdist(features, memory_bank)
# 最小化距离,使特征靠近记忆特征
loss = torch.mean(torch.min(distances, dim=1)[0])
return loss
内存库模块
管理内存库,存储正常样本的特征:
class MemoryBank:
def __init__(self, capacity):
self.capacity = capacity
self.bank = torch.randn(capacity, feature_dim) # 初始化内存库
def update(self, features):
# 更新内存库
self.bank = torch.cat((self.bank, features), dim=0)
self.bank = self.bank[-self.capacity:] # 保持内存库大小不变
训练与测试
训练时优化特征自适应模块,测试时计算异常分数:
# 训练
def train(model, memory_bank, dataloader, optimizer):
model.train()
for images, _ in dataloader:
features = model(images)
loss = coupled_hypersphere_loss(features, memory_bank.bank)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 测试
def test(model, memory_bank, dataloader):
model.eval()
anomaly_scores = []
with torch.no_grad():
for images, _ in dataloader:
features = model(images)
distances = torch.cdist(features, memory_bank.bank)
anomaly_scores.append(torch.min(distances, dim=1)[0])
return torch.cat(anomaly_scores)
模型训练
配置参数
在 config.py
中设置训练参数:
class Config:
batch_size = 32
learning_rate = 1e-4
epochs = 50
memory_bank_capacity = 10000
feature_dim = 512
训练过程
运行 train.py
:
from models.cfa_model import FeatureExtractor, MemoryBank
from utils.data_loader import get_dataloader
from config import Config
if __name__ == "__main__":
config = Config()
model = FeatureExtractor()
memory_bank = MemoryBank(config.memory_bank_capacity)
dataloader = get_dataloader(config.batch_size)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
for epoch in range(config.epochs):
train(model, memory_bank, dataloader, optimizer)
print(f"Epoch {epoch+1}, Loss: {loss.item()}")
模型测试
运行 test.py
:
from models.cfa_model import FeatureExtractor, MemoryBank
from utils.data_loader import get_test_dataloader
from config import Config
if __name__ == "__main__":
config = Config()
model = FeatureExtractor()
memory_bank = MemoryBank(config.memory_bank_capacity)
test_dataloader = get_test_dataloader(config.batch_size)
anomaly_scores = test(model, memory_bank, test_dataloader)
# 保存异常分数
torch.save(anomaly_scores, "anomaly_scores.pth")
结果评估
使用AUROC(Area Under the ROC Curve)等指标评估异常检测性能:
from sklearn.metrics import roc_auc_score
import numpy as np
# 假设 ground_truth 是真实的异常标签
ground_truth = np.load("ground_truth.npy")
anomaly_scores = torch.load("anomaly_scores.pth").numpy()
auroc = roc_auc_score(ground_truth, anomaly_scores)
print(f"AUROC: {auroc:.4f}")