NeurIPS 2024 | Mamba杀入异常检测!MambaAD:第一个使用Mamba进行多类无监督异常检测...

点击下方卡片,关注“CVer”公众号

AI/CV重磅干货,第一时间送达

点击进入—>【Mamba和异常检测】交流群

添加微信号:CVer111,小助手会拉你进群!

扫描下方二维码,加入CVer学术星球!可以获得最新顶会/顶刊上的论文idea和CV从入门到精通资料,及最前沿应用!发论文/搞科研/涨薪,强烈推荐!

4fa520a57016708f1f44db7d9a7bfe30.png

编辑:CVer 微信公众号| 作者:CVer粉丝投稿

a3a2bd40b420da3dc7e7d223a2fb8d99.png

(来自浙大,腾讯优图,南洋理工大学)

论文:https://arxiv.org/abs/2404.06564

主页:lewandofskee.github.io/projects/MambaAD/

代码:github.com/lewandofskee/MambaAD

内容总结(太长不看版)

过去基于CNN和Transformer的结构的算法被广泛应用于异常检测领域取得了一定的效果。但是CNN无法处理长距离信息的关联性,而Transformer受限于其平方级别的计算复杂度。最近基于Mamba的模型凭借着其长距离建模的出色能力与线性计算复杂度引起了广泛的研究。在本文中,我们首次将Mamba应用于多类无监督异常检测领域并提出MambaAD如图1所示包含一个预训练的CNN编码器和由不同尺度的局部增强状态空间(LSS)模块成的Mamba解码器。被提出的LSS 模块包含并行的连续混合状态空间(HSS)模块和多核的卷积操作,能够同时具有长距离的建模能力与局部信息的建模能力。HSS模块所包含的混合扫描(HS)编码器将特征图编码至5种不同的扫描方式和8种不同的扫描方向并输送至SSM中建立全局的联系。其中我们采用了Hilbert扫描方式和8种扫描方向有利于提升特征序列的建模能力。大量实验表明,我们在6种不同异常检测数据集上、7种不同的指标上取得了SoTA,证明了Mamba AD方法了有效性。   

d4b1373a818a772be59ebccc74c73f25.png

图1 MambaAD框架结构图

1、引言

尽管基于合成和基于特征编码的方法在AD领域取得了不错的效果,但是这些方法需要额外的设计与不可轻易扩展的框架。基于重建的方法如RD4AD和UniAD具有非常好的效果与较好的可扩展性。RD4AD提出了预训练教师模型与学生模型并在多尺度特征层面上进行异常值对比。尽管基于CNN的RD4AD在多尺度上的有着捕捉相邻的信息能力并取得了较好的性能,但是其无法建立长距离的相关性。首个多类异常检测算法UniAD是基于预训练的CNN编码器和Transformer解码器的架构。尽管Transformer有着全局建模的能力,但是由于其平方级别计算复杂度,UniAD仅在最小尺度的特征图上对比得到异常图,这无疑会减弱模型性能。

d37c51d1dfba3703d8a053bd8ee2d67c.png    

图2 MambaAD与基于CNN的RD4AD和Transformer的UniAD方法对比

2、MambaAD贡献

最近,Mamba在大语言模型中取得了出色的表现,有着远小于transformer的线性复杂度并且能够与transformer相媲美的效果。近期大量的工作将Mamba引入视觉领域,涌现大量基于Mamba的研究工作。本文首次将Mamba引入异常检测领域,构建了MambaAD架构有着全局+局部的建模能力,并且利用它的线性复杂度在多尺度上计算异常图并且有着较低的参数量与计算复杂度。具体来说MambaAD使用金字塔结构的自编码器结构来重建多尺度的特征,通过一个预训练编码器和提出的基于Mamba结构的解码器。其中基于Mamba结构的解码器由不同尺度与数量的局部增强状态空间(Locality-Enhanced State Space) LSS module组成。LSS module由两部分组成:连续的(Hybrid State Space) HSS模块用于捕捉全局的信息和并联多核的卷积操作用来建立局部的联系。最终的输出特征不仅包含基于Mamba的全局建模能力,还包含了基于CNN的建立局部相关性的能力。所提出的HSS模块探索了5种不同的扫描方式和8种不同的扫描方向,其中的HS编码器和解码器分别将特征图编码至不同的扫描方式和方向的序列并解码。HSS模块能够增强在多个方向上的全局感受野并且所使用的Hilbert的扫描方式也更加适用于工业产品位置集中在中心区域的特点。通过在不同尺度的特征图上计算异常图并相加,MambaAD在6个不同的异常检测数据集上取得了SoTA性能并且模型参数量与计算复杂度也非常低。具体来说,我们的贡献如下:

1)我们提出了MambaAD首次使用Mamba来解决多类无监督异常检测任务,它能够在很少的模型参数量和计算复杂度上进行多尺度训练与推理。

2)我们设计了一个LSS module,连接的HSS模块和并行的多内核卷积分别提取全局特征相关性与局部信息关联性,实现全局加局部的统一建模。

3)我们探索了HSS模块即5种方法8种多方向的混合扫描方式如图3所示,能够增强复杂的异常检测图像在不同类别不同形态下的全局建模能力。   

4)我们证明了MambaAD在多类异常检测任务的优越性和高效性。在6个不同的异常检测数据集上达到SOTA如表1所示,并且有着非常低的模型参数与计算复杂度见表2。

cb35e8ad3e360f86012dd492431fbd77.png

图3 五种不同的扫描方式和八种扫描方向  

8de301e632419f332f3f4dba12804010.png

表1 三个异常检测数据集对比结果,更多详细结果参考文章附录        

    

8ba3621e5fb0bcef8028a37f6dec2d13.png

表2 MambaAD与SOTA方法在参数量、计算复杂度和效果上对比

bf04c77ed5b06b732dd4112a339dd8f0.png

图4 MambaAD与SOTA方法定性实验结果对比

 
 

何恺明在MIT授课的课件PPT下载

 
 

在CVer公众号后台回复:何恺明,即可下载本课程的所有566页课件PPT!赶紧学起来!

ECCV 2024 论文和代码下载

在CVer公众号后台回复:ECCV2024,即可下载ECCV 2024论文和代码开源的论文合集

CVPR 2024 论文和代码下载

在CVer公众号后台回复:CVPR2024,即可下载CVPR 2024论文和代码开源的论文合集

Mamba、多模态和扩散模型交流群成立

 
 
扫描下方二维码,或者添加微信号:CVer111,即可添加CVer小助手微信,便可申请加入CVer-Mamba、多模态学习或者扩散模型微信交流群。另外其他垂直方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch、TensorFlow和Transformer、NeRF、3DGS、Mamba等。
一定要备注:研究方向+地点+学校/公司+昵称(如Mamba、多模态学习或者扩散模型+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群

 
 
▲扫码或加微信号: CVer111,进交流群
CVer计算机视觉(知识星球)来了!想要了解最新最快最好的CV/DL/AI论文速递、优质实战项目、AI行业前沿、从入门到精通学习教程等资料,欢迎扫描下方二维码,加入CVer计算机视觉(知识星球),已汇集上万人!

▲扫码加入星球学习
 
 
▲点击上方卡片,关注CVer公众号
整理不易,请赞和在看
### 使用 MambaAD 训练自定义数据集 为了使用 MambaAD 方法训练自定义的数据集,需遵循特定的流程来准备和处理数据。MambaAD 是一种用于多类无监督异常检测的方法,其核心组件包括一个预训练的 CNN 编码器以及由多个不同尺度的局部增强状态空间 (LSS) 模块组成的 Mamba 解码器[^1]。 #### 数据预处理 在开始训练之前,确保输入图像已经过适当预处理。这通常涉及标准化操作,使得每张图片具有零均值和单位方差。对于彩色图像而言,还需要调整大小至统一尺寸以便于批量处理。 ```python from torchvision import transforms transform = transforms.Compose([ transforms.Resize((256, 256)), # 调整分辨率 transforms.ToTensor(), # 转换为 Tensor 类型并归一化到 [0,1] ]) ``` #### 准备 PyTorch Dataset 和 DataLoader 创建继承 `torch.utils.data.Dataset` 的子类以加载本地文件夹内的所有样本,并实现必要的接口函数如 `__len__()`, `__getitem__()`. 接着通过 `DataLoader` 来管理批次读取过程。 ```python import os from PIL import Image from torch.utils.data import Dataset, DataLoader class CustomImageDataset(Dataset): def __init__(self, img_dir, transform=None): self.img_labels = [] self.transform = transform files = sorted(os.listdir(img_dir)) for file_name in files: label = int('abnormal' not in file_name.lower()) path = os.path.join(img_dir, file_name) self.img_labels.append([path,label]) def __len__(self): return len(self.img_labels) def __getitem__(self, idx): image_path, label = self.img_labels[idx] image = Image.open(image_path).convert("RGB") if self.transform is not None: image = self.transform(image) sample = {"image": image, "label": label} return sample train_dataset = CustomImageDataset("./data/train", transform=transform) test_dataset = CustomImageDataset("./data/test", transform=transform) batch_size = 32 num_workers = 4 train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) test_loader = DataLoader(test_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=num_workers) ``` #### 加载预训练模型与微调 由于 MambaAD 已经包含了经过充分训练过的 CNN 编码器,在实际应用中可以直接下载官方发布的权重文件来进行迁移学习。如果希望进一步优化性能,则可以在目标域内继续迭代更新参数直至收敛为止。 ```python import mambaad # 假设这是安装好的库名 model = mambnad.MambaAD(pretrained=True) if use_cuda: model.cuda() optimizer = optim.Adam(model.parameters(), lr=learning_rate) scheduler = StepLR(optimizer, step_size=7, gamma=0.1) for epoch in range(num_epochs): train_one_epoch(model, optimizer, scheduler, train_loader) def train_one_epoch(model, optimizer, scheduler, data_loader): ... ``` 上述代码片段展示了如何构建适合 MambaAD 输入格式的数据管道结构;同时也介绍了获取预训练版本的方式及其后续可能采取的学习策略。请注意具体细节会依据所选框架而有所差异,请参照相应文档完成剩余部分开发工作。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值